refactor: apply go fix modernizers for Go 1.26
Automated fixes: interface{} → any, range-over-int, t.Context(),
wg.Go(), strings.SplitSeq, strings.Builder, slices.Contains,
maps helpers, min/max builtins.
Co-Authored-By: Virgil <virgil@lethean.io>
parent 8c8b449d66
commit f75458bce6
40 changed files with 4811 additions and 187 deletions
211
CLAUDE.md
Normal file
|
|
@ -0,0 +1,211 @@
|
|||
# CLAUDE.md
|
||||
|
||||
## Project Overview
|
||||
|
||||
LEM (Lethean Ethics Model) — training protocol and tooling for ethical alignment of language models via layered curriculum training.
|
||||
|
||||
LEM is the first external consumer of the **Core Go Framework** (`forge.lthn.ai/core/*`). The framework provides Metal inference, grammar scoring, CLI/TUI, lifecycle management, and cross-platform backends. LEM brings the protocol — curriculum, sandwich format, training philosophy — and imports the framework for everything else.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Framework Dependency
|
||||
|
||||
```
|
||||
lthn/LEM (binary — this repo)
|
||||
├── core/go Framework: DI, lifecycle, CLI/TUI, config, process, storage, logging
|
||||
├── core/go-ml Scoring engine, backends, Metal memory management
|
||||
├── core/go-inference Shared TextModel/Backend/Token interfaces (platform-agnostic)
|
||||
├── core/go-mlx Native Metal GPU inference (darwin/arm64, SetMemoryLimit/SetCacheLimit)
|
||||
├── core/go-i18n Grammar v3 scoring engine (reversal)
|
||||
└── core/go-api REST framework (future: LEM Lab API)
|
||||
```
|
||||
|
||||
LEM's own binary, own repo, own identity — but 90% of the logic is supported by the Core Go Framework. The framework was prepared specifically for this phase (14-22 Feb 2026).
|
||||
|
||||
**Cross-platform**: `go-inference` provides shared interfaces that work with both `go-mlx` (Apple Metal, macOS) and `go-rocm` (AMD ROCm, Linux homelab). LEM runs wherever the framework runs.
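For illustration, that platform split can live in thin build-tagged files on the LEM side. The darwin file mirrors `pkg/lem/backend_metal.go` shown later in this commit; the Linux counterpart below (file name, import path, build tags) is an assumption for illustration, not something this commit adds.

```go
// file: pkg/lem/backend_rocm.go — hypothetical Linux counterpart to backend_metal.go
//go:build linux && amd64

package lem

// Blank import registers the ROCm backend via its init(), the same way
// backend_metal.go registers Metal on darwin/arm64. The import path is
// assumed from the go-rocm naming above.
import _ "forge.lthn.ai/core/go-rocm"
```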
|
||||
|
||||
**Wiki documentation**: All core repos have wikis at `forge.lthn.ai/core/{repo}.wiki.git` (e.g. `core/go.wiki.git`).
|
||||
|
||||
### Core Go Package Map (`forge.lthn.ai/core/go`)
|
||||
|
||||
| Package | Purpose | LEM Use |
|
||||
|---------|---------|---------|
|
||||
| `pkg/framework/core` | DI container, service lifecycle, message bus | Service orchestration |
|
||||
| `pkg/cli` | CLI framework, command routing, TUI | Commands, Viewport, Spinner, ProgressBar |
|
||||
| `pkg/lab` | LEM Lab monitoring dashboard (collectors, SSE, web UI) | Training progress, benchmarks, golden set stats |
|
||||
| `pkg/process` | Process execution with streaming output | Training subprocess management |
|
||||
| `pkg/config` | Configuration management | `.core/ai/` config hierarchy |
|
||||
| `pkg/log` | Structured logging service | Training logs |
|
||||
| `pkg/io` | Abstract storage (local, S3, SFTP, WebDAV) | Model/adapter storage |
|
||||
| `pkg/workspace` | Encrypted workspace storage | Secure model data |
|
||||
| `pkg/cache` | Caching utilities | Inference caching |
|
||||
| `pkg/store` | Key-value storage | Training state persistence |
|
||||
| `pkg/manifest` | Package manifest signing and verification | Model provenance |
|
||||
| `pkg/plugin` | Plugin installation, loading, versioning | Future: training plugins |
|
||||
| `pkg/ws` | WebSocket hub for real-time streaming | Future: LEM Lab live UI |
|
||||
| `pkg/webview` | Chrome DevTools Protocol client | Future: LEM Lab browser UI |
|
||||
| `pkg/help` | Help/documentation search | CLI help system |
|
||||
| `pkg/ratelimit` | Rate limiting | API rate control |
|
||||
| `pkg/repos` | Git repository registry | Multi-repo management |
|
||||
| `pkg/marketplace` | Plugin/service marketplace | Future: model marketplace |
|
||||
| `pkg/session` | Session management | Training sessions |
|
||||
| `pkg/coredeno` | Deno runtime sidecar integration | Future: scripting |
|
||||
|
||||
### Planned: core/go-lem
|
||||
|
||||
`pkg/lab` (currently in `core/go`) will be extracted to a new `core/go-lem` package. This becomes the LEM protocol layer:
|
||||
- Lab dashboard (collectors, SSE, web UI)
|
||||
- Distill logic (bare probes, sandwich output, grammar gate, best-of-N)
|
||||
- Training types and curriculum definitions
|
||||
- LEM-specific config (`.core/ai/` hierarchy)
|
||||
|
||||
```
|
||||
lthn/LEM (thin binary — wires everything together)
|
||||
├── core/go-lem LEM protocol layer (distill, lab, curriculum)
|
||||
├── core/go-ml Scoring engine, Backend interface
|
||||
├── core/go-mlx Metal GPU
|
||||
├── core/go-i18n Grammar v3
|
||||
└── core/go Framework (CLI/TUI, lifecycle)
|
||||
```
|
||||
|
||||
### Distill Migration: go-inference → go-ml Backend
|
||||
|
||||
LEM's `distill.go` currently imports `go-inference` directly with no Metal memory management. This causes unbounded memory growth. The fix is to migrate to `go-ml`'s `Backend` interface, which wraps `go-inference` with memory controls.
|
||||
|
||||
**Current** (distill.go — broken memory):
|
||||
```go
|
||||
model, err := inference.LoadModel(modelCfg.Paths.Base) // no memory limits
|
||||
for token := range model.Chat(ctx, messages, opts...) { ... } // raw iter.Seq
|
||||
```
|
||||
|
||||
**Target** (following `core ml ab` pattern):
|
||||
```go
|
||||
mlx.SetCacheLimit(cacheGB * 1024 * 1024 * 1024) // e.g. 8GB
|
||||
mlx.SetMemoryLimit(memGB * 1024 * 1024 * 1024) // e.g. 16GB
|
||||
backend, err := ml.NewMLXBackend(modelPath) // wraps go-inference
|
||||
resp, err := backend.Chat(ctx, messages, ml.GenOpts{ // managed inference
|
||||
Temperature: 0.4,
|
||||
MaxTokens: 1024,
|
||||
})
|
||||
runtime.GC() // between probes
|
||||
```
|
||||
|
||||
`ml.NewMLXBackend()` → `inference.LoadModel()` → `InferenceAdapter` (satisfies `ml.Backend` + `ml.StreamingBackend`). Same model, same Metal inference, but with memory limits and GC discipline.
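A minimal compile-time sketch of that contract — the exported `InferenceAdapter` name is taken from the description above and may differ in go-ml, so treat it as an assumption:

```go
// Compile-time assertions: the adapter returned by ml.NewMLXBackend must
// satisfy both interfaces. Assumed type name; adjust to go-ml's actual export.
var (
	_ ml.Backend          = (*ml.InferenceAdapter)(nil)
	_ ml.StreamingBackend = (*ml.InferenceAdapter)(nil)
)
```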
|
||||
|
||||
### core ml train (go-ml, blocked)
|
||||
|
||||
`cmd_train.go` exists in go-ml but is `//go:build ignore` — blocked on go-mlx exporting the concrete model type needed for training (`ApplyLoRA`, `Forward`, `NewCache`, `Tokenizer`). The full loop is written: LoRA, AdamW, VJP, masked cross-entropy loss, Gemma + Qwen3 chat templates. When go-mlx exports the training API, `core ml train` becomes the training backend.
|
||||
|
||||
### Kernel A/B Testing
|
||||
|
||||
The `.txt` kernel was a quick glob/cat of the kernel directory — not scientifically selected. Kernel format must be A/B tested properly.
|
||||
|
||||
**Kernel variants** (in `Axioms-of-Conscious-Systems/kernel/`):
|
||||
- `axioms.json` — Canonical (identical to `lek-1-kernel.json`). 5 axioms with id, name, statement, function, resolution.
|
||||
- `terms.json` — Expands on axioms.json. Precision definitions (consciousness, prime-imperative, reality-anchoring, etc.). Same domain, deeper grind.
|
||||
- `claude-native.json` — Claude's compact interpretation. Core[] array, operational map (fn/when/weight), fast paths (harm→1,3,5; autonomy→4,5; self-doubt→2).
|
||||
- `claude.json` — Agent-specific operational layer extending axioms.json.
|
||||
|
||||
**Test with `core ml ab`** on base (untrained) models:
|
||||
```bash
|
||||
core ml ab --model-path /Volumes/Data/lem/gemma-3-1b-it-base \
|
||||
--kernel axioms=data/kernels/lek-1-kernel.json \
|
||||
--kernel claude-native=/path/to/claude-native.json \
|
||||
--kernel terms=/path/to/terms.json \
|
||||
--cache-limit 8 --mem-limit 16
|
||||
```
|
||||
|
||||
Baseline (no kernel) + each kernel condition → heuristic scores → comparison table with delta per probe. True science, not hunches.
|
||||
|
||||
### Lineage
|
||||
|
||||
`core ml sandwich` pioneered the sandwich generation pattern. `lem distill` borrowed it and added grammar v3 scoring, quality gate, and best-of-N selection. The core framework then matured with proper Metal memory management (`mlx.SetMemoryLimit`, `mlx.SetCacheLimit`), TUI utilities, and lifecycle support. Now LEM imports the full framework stack.
|
||||
|
||||
## Build & Run
|
||||
|
||||
```bash
|
||||
go build -o lem . # Build the lem binary
|
||||
go install . # Install to $GOPATH/bin
|
||||
```
|
||||
|
||||
## Key Commands
|
||||
|
||||
```bash
|
||||
lem distill --model gemma3/1b --probes eval # Distill probes through LEM model (bare probes, sandwich output)
|
||||
lem score --input responses.jsonl # Score with grammar v3
|
||||
lem probe --model gemma3-4b-it # Generate + score probes
|
||||
lem compare --old old.json --new new.json # Compare score files
|
||||
lem export # Export golden set to training JSONL
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
- `.core/ai/ai.yaml` — Global AI config (backend, scorer, generation defaults, distill settings)
|
||||
- `.core/ai/models/gemma3/{size}.yaml` — Per-model config (paths, kernel, lessons, baselines)
|
||||
- `.core/ai/probes.yaml` — Probe sets mapped to curriculum phases
|
||||
|
||||
## Training Curriculum
|
||||
|
||||
| Phase | Probe Set | Format | Description |
|
||||
|-------|-----------|--------|-------------|
|
||||
| 0 | `core` | Sandwich | 101 core probes — LEK axiom absorption |
|
||||
| 1 | `zen` | No LEK | Allen/Watts/composure — philosophical substrate |
|
||||
| 2 | `eval` | Sandwich | 200 expanded probes — deeper alignment |
|
||||
| 3 | `ethics` | Freeflow | 260 adversarial/cultural/sovereignty probes |
|
||||
| 4 | `tension` | Freeflow | Geopolitical multi-perspective scenarios |
|
||||
| 5 | `creative` | Freeflow | Voice and style probes |
|
||||
|
||||
### Sandwich Format
|
||||
|
||||
```
|
||||
[LEK-1 kernel JSON]
|
||||
|
||||
[Probe prompt]
|
||||
|
||||
[LEK-1-Sig quote]
|
||||
```
|
||||
|
||||
Single user message. No system role. Kernel is `data/kernels/lek-1-kernel.json`. Sig is `data/kernels/lek-1-sig.txt`.
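A minimal sketch of assembling that training record in Go (uses only `encoding/json` and `strings`) — the `messages`/`role`/`content` field names match the training JSONL used elsewhere in this repo, but the helper itself is illustrative, not an existing API:

```go
// sandwichRecord wraps one probe/response pair in the sandwich format:
// kernel + probe + sig as a single user message, response as the assistant turn.
func sandwichRecord(kernelJSON, sig, probe, response string) ([]byte, error) {
	user := strings.TrimSpace(kernelJSON) + "\n\n" + probe + "\n\n" + strings.TrimSpace(sig)
	return json.Marshal(map[string]any{
		"messages": []map[string]string{
			{"role": "user", "content": user},
			{"role": "assistant", "content": response},
		},
	})
}
```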
|
||||
|
||||
### LEM Models as Distillation Engines
|
||||
|
||||
LEM models (e.g. LEM-Gemma3-1B) have axioms in their weights. When distilling:
|
||||
- **Do NOT** send the kernel in the inference prompt — the model already has it
|
||||
- Model sees bare probes only. Output JSONL gets sandwich wrapping (kernel + probe + sig as user message).
|
||||
- The 1B serves as the lab distillation engine (700MB, runs alongside larger models)
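Inference-side, the rules above amount to a single bare-probe user message — the same shape the distill migration plan in `docs/plans/` uses:

```go
// The kernel is already in the weights — it is NOT included here.
messages := []ml.Message{
	{Role: "user", Content: probe.Prompt}, // bare probe, no system role
}
```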
|
||||
|
||||
### Scoring
|
||||
|
||||
- **Grammar v3** (`go-i18n/reversal`) — Primary metric. Composite of tense entropy, vocab richness, question ratio, verb/noun diversity
|
||||
- **Delta mode** — Uplift, echo, enrichment, sycophancy between prompt and response
|
||||
- **Quality gate** — `min_score` in `ai.yaml` (default 40.0), responses below are rejected
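Taken together, best-of-N plus the gate reduces to a small selection loop. A sketch, assuming one grammar v3 composite per generation and `MinScore` loaded from `ai.yaml`:

```go
// Keep the highest-scoring run that clears the gate; reject the probe if none do.
best := -1
for i, s := range scores { // len(scores) == distill.runs
	if s < cfg.MinScore { // default 40.0
		continue
	}
	if best < 0 || s > scores[best] {
		best = i
	}
}
// best == -1 → every run failed the gate; no training record is written.
```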
|
||||
|
||||
### Data Layout
|
||||
|
||||
```
|
||||
data/
|
||||
kernels/ lek-1-kernel.json, lek-1-sig.txt
|
||||
models/gemma3/ Symlinks to /Volumes/Data/lem/
|
||||
training/
|
||||
lem/
|
||||
ethics/ Core (101), rephrased (404), adversarial, cultural, naive, sovereignty
|
||||
zen/ Golden lessons, seeds, config
|
||||
eval/ test-200.json (P2 candidates)
|
||||
model/gemma3/ Training configs + assembled JSONL per model size
|
||||
pkg/lem/ Go code (distill, scoring, config, export)
|
||||
```
|
||||
|
||||
## Rules
|
||||
|
||||
Read `RULES.md` for the full protocol. Key points:
|
||||
- No Python in production — Go tooling only
|
||||
- Once fused, it stays — verify before merging adapters
|
||||
- LEK must never appear in production chat data
|
||||
- JSON kernel for models (`lek-1-kernel.json` is canonical, `.txt` removed)
|
||||
- Distill and Teach are different operations — never confuse them
|
||||
|
||||
## Coding Standards
|
||||
|
||||
- Go 1.25+, standard library where possible
|
||||
- UK English in comments and docs
|
||||
- Licence: EUPL-1.2
|
||||
|
|
@ -22,9 +22,9 @@ type example struct {
|
|||
|
||||
// composureSources maps filename stems to metadata.
|
||||
var composureSources = map[string]struct {
|
||||
Domain string
|
||||
Author string
|
||||
Work string
|
||||
Domain string
|
||||
Author string
|
||||
Work string
|
||||
Prompts []string
|
||||
}{
|
||||
"consent-wollstonecraft-vindication": {
|
||||
|
|
@ -140,10 +140,7 @@ func main() {
|
|||
promptIdx := 0
|
||||
|
||||
for i := 0; i < len(paragraphs); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > len(paragraphs) {
|
||||
end = len(paragraphs)
|
||||
}
|
||||
end := min(i+chunkSize, len(paragraphs))
|
||||
chunk := strings.Join(paragraphs[i:end], "\n\n")
|
||||
|
||||
// Skip very short chunks.
|
||||
|
|
|
|||
261
cmd/dedup-check/main.go
Normal file
|
|
@ -0,0 +1,261 @@
|
|||
// dedup-check scans JSONL training files for duplicate prompts.
|
||||
// Reports exact duplicates (after whitespace normalisation) and which files contain them.
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type entry struct {
|
||||
File string
|
||||
Line int
|
||||
SeedID string
|
||||
Voice string
|
||||
Domain string
|
||||
Prompt string
|
||||
}
|
||||
|
||||
func main() {
|
||||
if len(os.Args) < 2 {
|
||||
fmt.Fprintf(os.Stderr, "Usage: dedup-check <dir-or-file> [...]\n")
|
||||
fmt.Fprintf(os.Stderr, "\nScans JSONL/JSON files for duplicate prompts.\n")
|
||||
fmt.Fprintf(os.Stderr, "Reports exact duplicates and shows which files contain them.\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
var files []string
|
||||
for _, arg := range os.Args[1:] {
|
||||
info, err := os.Stat(arg)
|
||||
if err != nil {
|
||||
log.Printf("skip %s: %v", arg, err)
|
||||
continue
|
||||
}
|
||||
if info.IsDir() {
|
||||
filepath.Walk(arg, func(path string, fi os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if !fi.IsDir() && (strings.HasSuffix(path, ".jsonl") || strings.HasSuffix(path, ".json")) {
|
||||
files = append(files, path)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
} else {
|
||||
files = append(files, arg)
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("scanning %d files", len(files))
|
||||
|
||||
// Map: normalised prompt → list of entries.
|
||||
exact := make(map[string][]entry)
|
||||
total := 0
|
||||
|
||||
for _, f := range files {
|
||||
entries, err := readEntries(f)
|
||||
if err != nil {
|
||||
log.Printf("skip %s: %v", f, err)
|
||||
continue
|
||||
}
|
||||
for _, e := range entries {
|
||||
key := normalise(e.Prompt)
|
||||
exact[key] = append(exact[key], e)
|
||||
total++
|
||||
}
|
||||
}
|
||||
|
||||
// Report duplicates.
|
||||
dupeGroups := 0
|
||||
dupeEntries := 0
|
||||
crossFile := 0
|
||||
|
||||
for _, entries := range exact {
|
||||
if len(entries) < 2 {
|
||||
continue
|
||||
}
|
||||
dupeGroups++
|
||||
dupeEntries += len(entries)
|
||||
|
||||
// Check if duplicates span multiple files.
|
||||
fileSet := make(map[string]bool)
|
||||
for _, e := range entries {
|
||||
fileSet[e.File] = true
|
||||
}
|
||||
if len(fileSet) > 1 {
|
||||
crossFile++
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Printf("\n=== Dedup Report ===\n")
|
||||
fmt.Printf("Files scanned: %d\n", len(files))
|
||||
fmt.Printf("Total prompts: %d\n", total)
|
||||
fmt.Printf("Unique prompts: %d\n", len(exact))
|
||||
fmt.Printf("Duplicate groups: %d\n", dupeGroups)
|
||||
fmt.Printf("Duplicate entries: %d\n", dupeEntries)
|
||||
fmt.Printf("Cross-file dupes: %d (same prompt in different files)\n", crossFile)
|
||||
|
||||
if crossFile > 0 {
|
||||
fmt.Printf("\n--- Cross-File Duplicates ---\n")
|
||||
shown := 0
|
||||
for prompt, entries := range exact {
|
||||
if len(entries) < 2 {
|
||||
continue
|
||||
}
|
||||
fileSet := make(map[string]bool)
|
||||
for _, e := range entries {
|
||||
fileSet[e.File] = true
|
||||
}
|
||||
if len(fileSet) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
shown++
|
||||
if shown > 50 {
|
||||
fmt.Printf("\n... and %d more cross-file groups\n", crossFile-50)
|
||||
break
|
||||
}
|
||||
|
||||
preview := prompt
|
||||
if len(preview) > 100 {
|
||||
preview = preview[:100] + "..."
|
||||
}
|
||||
fmt.Printf("\n[%d] %q\n", shown, preview)
|
||||
for _, e := range entries {
|
||||
seedInfo := ""
|
||||
if e.SeedID != "" {
|
||||
seedInfo = fmt.Sprintf(" seed=%s", e.SeedID)
|
||||
}
|
||||
if e.Voice != "" {
|
||||
seedInfo += fmt.Sprintf(" voice=%s", e.Voice)
|
||||
}
|
||||
fmt.Printf(" %s:%d%s\n", e.File, e.Line, seedInfo)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if dupeGroups > 0 && crossFile == 0 {
|
||||
fmt.Printf("\nAll duplicates are within the same file (no cross-file conflicts).\n")
|
||||
}
|
||||
|
||||
if dupeGroups == 0 {
|
||||
fmt.Printf("\nNo duplicates found.\n")
|
||||
}
|
||||
}
|
||||
|
||||
func readEntries(path string) ([]entry, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
text := strings.TrimSpace(string(data))
|
||||
if text == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Try as JSON array first.
|
||||
if text[0] == '[' {
|
||||
var arr []map[string]any
|
||||
if err := json.Unmarshal(data, &arr); err != nil {
|
||||
return nil, fmt.Errorf("parse JSON array: %w", err)
|
||||
}
|
||||
var entries []entry
|
||||
for i, obj := range arr {
|
||||
prompt := strVal(obj, "prompt")
|
||||
if prompt == "" {
|
||||
// Try messages format.
|
||||
prompt = extractFromMessages(obj)
|
||||
}
|
||||
if prompt == "" {
|
||||
continue
|
||||
}
|
||||
entries = append(entries, entry{
|
||||
File: path,
|
||||
Line: i + 1,
|
||||
SeedID: strVal(obj, "seed_id", "id"),
|
||||
Voice: strVal(obj, "voice"),
|
||||
Domain: strVal(obj, "domain"),
|
||||
Prompt: prompt,
|
||||
})
|
||||
}
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
// JSONL.
|
||||
var entries []entry
|
||||
scanner := bufio.NewScanner(strings.NewReader(text))
|
||||
scanner.Buffer(make([]byte, 4*1024*1024), 4*1024*1024)
|
||||
lineNo := 0
|
||||
for scanner.Scan() {
|
||||
lineNo++
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
var obj map[string]any
|
||||
if err := json.Unmarshal([]byte(line), &obj); err != nil {
|
||||
continue
|
||||
}
|
||||
prompt := strVal(obj, "prompt")
|
||||
if prompt == "" {
|
||||
prompt = extractFromMessages(obj)
|
||||
}
|
||||
if prompt == "" {
|
||||
continue
|
||||
}
|
||||
entries = append(entries, entry{
|
||||
File: path,
|
||||
Line: lineNo,
|
||||
SeedID: strVal(obj, "seed_id", "id"),
|
||||
Voice: strVal(obj, "voice"),
|
||||
Domain: strVal(obj, "domain"),
|
||||
Prompt: prompt,
|
||||
})
|
||||
}
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
// extractFromMessages pulls the user prompt from training format.
|
||||
func extractFromMessages(obj map[string]any) string {
|
||||
msgs, ok := obj["messages"]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
arr, ok := msgs.([]any)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
for _, m := range arr {
|
||||
msg, ok := m.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if strVal(msg, "role") == "user" {
|
||||
return strVal(msg, "content")
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// strVal extracts a string from a map, trying multiple keys.
|
||||
func strVal(obj map[string]any, keys ...string) string {
|
||||
for _, k := range keys {
|
||||
if v, ok := obj[k]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// normalise strips whitespace for comparison.
|
||||
func normalise(s string) string {
|
||||
return strings.Join(strings.Fields(s), " ")
|
||||
}
|
||||
BIN
datasets/grammar-scores/all.parquet
Normal file
Binary file not shown.
BIN
datasets/grammar-scores/base.parquet
Normal file
Binary file not shown.
1189
datasets/grammar-scores/grammar-delta-flat.jsonl
Normal file
File diff suppressed because it is too large
1189
datasets/grammar-scores/grammar-delta.jsonl
Normal file
File diff suppressed because it is too large
BIN
datasets/grammar-scores/trained.parquet
Normal file
Binary file not shown.
BIN
dedup-check
Executable file
Binary file not shown.
112
docs/plans/2026-02-22-distill-backend-migration-design.md
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
# Distill Backend Migration Design
|
||||
|
||||
Date: 2026-02-22
|
||||
Status: Approved
|
||||
|
||||
## Problem
|
||||
|
||||
LEM's `distill.go` uses `go-inference.LoadModel()` directly with no Metal memory management. This causes unbounded memory growth (memory pressure red zone on 96GB machine). The core framework's `go-ml` package provides a `Backend` interface with memory controls, proven in `core ml ab`.
|
||||
|
||||
## Solution: Two Tasks with Dependency
|
||||
|
||||
### Task A: go-ml Backend Result Type (upstream, go-ml repo)
|
||||
|
||||
Break the `Backend` interface to return a `Result` struct instead of bare `string`. This gives all consumers access to inference metrics (tok/s, token counts, timing) without reaching behind the abstraction.
|
||||
|
||||
**New type:**
|
||||
```go
|
||||
// inference.go
|
||||
type Result struct {
|
||||
Text string
|
||||
Metrics *inference.GenerateMetrics // nil for backends without metrics
|
||||
}
|
||||
```
|
||||
|
||||
**Interface change:**
|
||||
```go
|
||||
type Backend interface {
|
||||
Generate(ctx context.Context, prompt string, opts GenOpts) (Result, error)
|
||||
Chat(ctx context.Context, messages []Message, opts GenOpts) (Result, error)
|
||||
Name() string
|
||||
Available() bool
|
||||
}
|
||||
```
|
||||
|
||||
**StreamingBackend** unchanged (callback-based, metrics not per-token).
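Caller-side, the mechanical change listed in the table below looks roughly like this — a sketch only, with `backend`, `ctx`, `messages`, and `opts` assumed from the surrounding code; the field names follow the `Result` definition above and the `DecodeTokensPerSec` metric used by `cmd_ab`:

```go
res, err := backend.Chat(ctx, messages, opts)
if err != nil {
	return err
}
text := res.Text        // was: the bare string return
if res.Metrics != nil { // nil for HTTP/llama backends
	log.Printf("%.0f tok/s", res.Metrics.DecodeTokensPerSec)
}
```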
|
||||
|
||||
**Files changed (~50 call sites, all mechanical):**
|
||||
|
||||
| File | Change |
|
||||
|------|--------|
|
||||
| `inference.go` | Add `Result` struct, update `Backend`/`StreamingBackend` interfaces |
|
||||
| `adapter.go` | Return `Result{Text: b.String(), Metrics: a.model.Metrics()}` |
|
||||
| `backend_http.go` | Return `Result{Text: text}` (no metrics) |
|
||||
| `backend_llama.go` | Return `Result{Text: text}` (delegates to http) |
|
||||
| `service.go` | `Generate()` returns `Result` |
|
||||
| `expand.go` | `.Text` access |
|
||||
| `judge.go` | `.Text` access |
|
||||
| `agent_eval.go` | `.Text` access (~3 sites) |
|
||||
| `cmd/cmd_ab.go` | `.Text` + `.Metrics` for tok/s |
|
||||
| `cmd/cmd_sandwich.go` | `.Text` access |
|
||||
| `cmd/cmd_lesson.go` | `.Text` access |
|
||||
| `cmd/cmd_serve.go` | `.Text` access (~2 sites) |
|
||||
| `cmd/cmd_benchmark.go` | `.Text` + `.Metrics` for timing |
|
||||
| `cmd/cmd_sequence.go` | `.Text` access |
|
||||
| `backend_http_textmodel.go` | `.Text` access |
|
||||
| `api/routes.go` | `.Text` access |
|
||||
| Tests (~15 files) | `result` → `result.Text` |
|
||||
|
||||
**Downstream impact:**
|
||||
- `go-ai/mcp/tools_ml.go` — goes through `service.Generate()`, needs `.Text`
|
||||
- LEM — will consume in Task B
|
||||
|
||||
### Task B: LEM distill.go Migration (this repo, after Task A)
|
||||
|
||||
Replace raw `go-inference` with `go-ml` Backend in `distill.go`.
|
||||
|
||||
**Changes:**
|
||||
|
||||
1. **`pkg/lem/distill.go`:**
|
||||
- Replace `inference.LoadModel()` → `ml.NewMLXBackend()`
|
||||
- Replace iter.Seq token loop → `backend.Chat()` returning `Result`
|
||||
- Add `mlx.SetCacheLimit()` / `mlx.SetMemoryLimit()` before model load
|
||||
- Add `runtime.GC()` between probes
|
||||
- Use `result.Metrics` for tok/s logging (replaces `model.Metrics()`)
|
||||
- Add `--cache-limit` and `--mem-limit` flags (defaults: 8GB, 16GB)
|
||||
- Import changes: `go-ml` + `go-mlx` instead of raw `go-inference`
|
||||
|
||||
2. **`pkg/lem/config.go`:**
|
||||
- Add `CacheLimit` / `MemoryLimit` to `AIConfig` (or `DistillConfig`)
|
||||
- Add to `ModelConfig` for per-model override
|
||||
- Update `MergeGenerate` or add `MergeDistill` for memory config merge
|
||||
|
||||
3. **`pkg/lem/backend_metal.go`:**
|
||||
- May need adjustment (currently just `import _ "go-mlx"`)
|
||||
|
||||
4. **`.core/ai/ai.yaml`:**
|
||||
- Add `cache_limit: 8` and `memory_limit: 16` under `distill:` section
|
||||
|
||||
**What stays the same:**
|
||||
- Grammar v3 scoring (`go-i18n/reversal`) — unchanged
|
||||
- Sandwich output format — unchanged
|
||||
- Bare probe inference (model sees probe only) — unchanged
|
||||
- Best-of-N selection — unchanged
|
||||
- Quality gate — unchanged
|
||||
- All probe loading, config merging, output writing — unchanged
|
||||
|
||||
**Reference implementation:** `go-ml/cmd/cmd_ab.go` lines 218-228 (memory setup) + 252-258 (Chat + GC pattern)
|
||||
|
||||
## Execution Order
|
||||
|
||||
1. Agent dispatched to go-ml repo (Task A) — break Backend interface, update all callers
|
||||
2. Build + test go-ml to confirm nothing breaks
|
||||
3. Agent dispatched to LEM repo (Task B) — migrate distill.go, depends on Task A
|
||||
4. Build + test LEM, run `lem distill --dry-run` to verify
|
||||
5. Run actual distill with memory limits, monitor memory pressure
|
||||
|
||||
## Design Decisions
|
||||
|
||||
- **Break the interface** (not add new method): Clean, no dual-API confusion. All callers are internal to the fleet.
|
||||
- **`Result.Metrics` is pointer, nil-safe**: HTTP and llama backends don't have Metal metrics. Callers check `if result.Metrics != nil`.
|
||||
- **Memory defaults 8GB cache / 16GB limit**: Conservative for 1B model on 96GB machine. Flags allow override.
|
||||
- **`runtime.GC()` between probes**: Matches `cmd_ab.go` pattern, prevents incremental memory leak.
|
||||
564
docs/plans/2026-02-22-distill-migration.md
Normal file
|
|
@ -0,0 +1,564 @@
|
|||
# LEM Distill Backend Migration Implementation Plan
|
||||
|
||||
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||
|
||||
**Goal:** Replace raw `go-inference` usage in `distill.go` with `go-ml` Backend interface, adding Metal memory management to prevent unbounded memory growth.
|
||||
|
||||
**Architecture:** `distill.go` currently calls `inference.LoadModel()` directly and iterates tokens via `model.Chat()` (iter.Seq). We replace this with `ml.NewMLXBackend()` which wraps the same model in an `InferenceAdapter` providing memory limits (`mlx.SetCacheLimit`/`SetMemoryLimit`), GC discipline between probes, and the new `Result{Text, Metrics}` return type for tok/s logging. The reference implementation is `go-ml/cmd/cmd_ab.go`.
|
||||
|
||||
**Tech Stack:** Go 1.25, `forge.lthn.ai/core/go-ml` (Backend, GenOpts, Result, Message, NewMLXBackend), `forge.lthn.ai/core/go-mlx` (SetCacheLimit, SetMemoryLimit), `forge.lthn.ai/core/go-inference` (GenerateMetrics — via Result.Metrics)
|
||||
|
||||
**Design doc:** `docs/plans/2026-02-22-distill-backend-migration-design.md`
|
||||
|
||||
---
|
||||
|
||||
### Task 1: Add go-ml to go.mod
|
||||
|
||||
`go-ml` is in the `replace` block but not in the `require` block. The compiler will refuse to import it until it's required.
|
||||
|
||||
**Files:**
|
||||
- Modify: `go.mod`
|
||||
|
||||
**Step 1: Add go-ml to require block**
|
||||
|
||||
Add this line to the first `require` block in `go.mod`, between `go-inference` and `go-duckdb`:
|
||||
|
||||
```
|
||||
forge.lthn.ai/core/go-ml v0.0.0-00010101000000-000000000000
|
||||
```
|
||||
|
||||
The version doesn't matter because the `replace` directive overrides it.
|
||||
|
||||
**Step 2: Run go mod tidy**
|
||||
|
||||
Run: `cd /Users/snider/Code/LEM && go mod tidy`
|
||||
|
||||
This will resolve the version and pull in any transitive deps from go-ml.
|
||||
|
||||
**Step 3: Verify build still works**
|
||||
|
||||
Run: `cd /Users/snider/Code/LEM && go build ./...`
|
||||
Expected: Clean build (go-ml is now available but not yet imported)
|
||||
|
||||
**Step 4: Commit**
|
||||
|
||||
```bash
|
||||
cd /Users/snider/Code/LEM
|
||||
git add go.mod go.sum
|
||||
git commit -m "$(cat <<'EOF'
|
||||
chore: add go-ml to go.mod require block
|
||||
|
||||
Prerequisite for distill migration from raw go-inference to
|
||||
go-ml Backend interface with memory management.
|
||||
|
||||
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
|
||||
EOF
|
||||
)"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 2: Add memory config fields to DistillConfig
|
||||
|
||||
Add `CacheLimit` and `MemoryLimit` fields to `DistillConfig` in `config.go`, and add corresponding YAML entries to `ai.yaml`.
|
||||
|
||||
**Files:**
|
||||
- Modify: `pkg/lem/config.go:38-42`
|
||||
- Modify: `.core/ai/ai.yaml:27-29`
|
||||
|
||||
**Step 1: Add fields to DistillConfig**
|
||||
|
||||
In `pkg/lem/config.go`, replace the `DistillConfig` struct (lines 39-42):
|
||||
|
||||
```go
|
||||
// DistillConfig holds distillation defaults.
|
||||
type DistillConfig struct {
|
||||
Runs int `yaml:"runs"`
|
||||
MinChars int `yaml:"min_chars"`
|
||||
CacheLimit int `yaml:"cache_limit"` // Metal cache limit in GB (0 = no limit)
|
||||
MemoryLimit int `yaml:"memory_limit"` // Metal memory limit in GB (0 = no limit)
|
||||
}
|
||||
```
|
||||
|
||||
**Step 2: Add YAML entries to ai.yaml**
|
||||
|
||||
In `.core/ai/ai.yaml`, replace the `distill:` block (lines 27-29):
|
||||
|
||||
```yaml
|
||||
# Distillation defaults.
|
||||
distill:
|
||||
runs: 3 # Generations per probe (best kept)
|
||||
min_chars: 20 # Reject responses shorter than this
|
||||
cache_limit: 8 # Metal cache limit in GB (0 = no limit)
|
||||
memory_limit: 16 # Metal memory limit in GB (0 = no limit)
|
||||
```
|
||||
|
||||
**Step 3: Verify build**
|
||||
|
||||
Run: `cd /Users/snider/Code/LEM && go build ./...`
|
||||
Expected: Clean build
|
||||
|
||||
**Step 4: Commit**
|
||||
|
||||
```bash
|
||||
cd /Users/snider/Code/LEM
|
||||
git add pkg/lem/config.go .core/ai/ai.yaml
|
||||
git commit -m "$(cat <<'EOF'
|
||||
feat(distill): add Metal memory limit config fields
|
||||
|
||||
CacheLimit (8GB) and MemoryLimit (16GB) in DistillConfig control
|
||||
mlx.SetCacheLimit/SetMemoryLimit before model load. Conservative
|
||||
defaults for 1B model on 96GB machine.
|
||||
|
||||
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
|
||||
EOF
|
||||
)"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 3: Add --cache-limit and --mem-limit flags to RunDistill
|
||||
|
||||
Wire the new config fields into CLI flags so they can be overridden per-run.
|
||||
|
||||
**Files:**
|
||||
- Modify: `pkg/lem/distill.go:38-51` (flag parsing section)
|
||||
|
||||
**Step 1: Add flags after existing flag declarations**
|
||||
|
||||
In `pkg/lem/distill.go`, add these two flags after the `root` flag (after line 47, before `fs.Parse`):
|
||||
|
||||
```go
|
||||
cacheLimit := fs.Int("cache-limit", 0, "Metal cache limit in GB (0 = use ai.yaml default)")
|
||||
memLimit := fs.Int("mem-limit", 0, "Metal memory limit in GB (0 = use ai.yaml default)")
|
||||
```
|
||||
|
||||
**Step 2: Add flag-to-config merge after existing overrides**
|
||||
|
||||
After the `*runs` override block (after line 71), add:
|
||||
|
||||
```go
|
||||
cacheLimitGB := aiCfg.Distill.CacheLimit
|
||||
if *cacheLimit > 0 {
|
||||
cacheLimitGB = *cacheLimit
|
||||
}
|
||||
memLimitGB := aiCfg.Distill.MemoryLimit
|
||||
if *memLimit > 0 {
|
||||
memLimitGB = *memLimit
|
||||
}
|
||||
```
|
||||
|
||||
**Step 3: Add memory limits to dry-run output**
|
||||
|
||||
In the dry-run block, after the `Generate:` line (after line 121), add:
|
||||
|
||||
```go
|
||||
fmt.Printf("Memory: cache=%dGB limit=%dGB\n", cacheLimitGB, memLimitGB)
|
||||
```
|
||||
|
||||
**Step 4: Verify build**
|
||||
|
||||
Run: `cd /Users/snider/Code/LEM && go build ./...`
|
||||
Expected: Clean build (flags are parsed but not yet used for model loading)
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
cd /Users/snider/Code/LEM
|
||||
git add pkg/lem/distill.go
|
||||
git commit -m "$(cat <<'EOF'
|
||||
feat(distill): add --cache-limit and --mem-limit flags
|
||||
|
||||
Override ai.yaml memory config per-run. Values in GB.
|
||||
Not yet wired to model loading.
|
||||
|
||||
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
|
||||
EOF
|
||||
)"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 4: Replace inference.LoadModel with ml.NewMLXBackend
|
||||
|
||||
The core migration: swap `inference.LoadModel()` + raw iter.Seq for `ml.NewMLXBackend()` + `backend.Chat()`. This is the biggest task.
|
||||
|
||||
**Files:**
|
||||
- Modify: `pkg/lem/distill.go` (imports, model loading, inference loop, metrics)
|
||||
|
||||
**Step 1: Update imports**
|
||||
|
||||
Replace the import block (lines 3-16) with:
|
||||
|
||||
```go
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"forge.lthn.ai/core/go-i18n/reversal"
|
||||
ml "forge.lthn.ai/core/go-ml"
|
||||
"forge.lthn.ai/core/go-mlx"
|
||||
)
|
||||
```
|
||||
|
||||
Key changes:
|
||||
- Remove `"forge.lthn.ai/core/go-inference"`
|
||||
- Add `ml "forge.lthn.ai/core/go-ml"` (named import to avoid collision with the package name)
|
||||
- Add `"forge.lthn.ai/core/go-mlx"` (for `mlx.SetCacheLimit`, `mlx.SetMemoryLimit`)
|
||||
- Add `"runtime"` (for `runtime.GC()`)
|
||||
|
||||
**Step 2: Replace model loading with memory-managed backend**
|
||||
|
||||
Replace the model loading block (lines 138-147):
|
||||
|
||||
```go
|
||||
// Set Metal memory limits before loading model.
|
||||
if cacheLimitGB > 0 {
|
||||
mlx.SetCacheLimit(uint64(cacheLimitGB) * 1024 * 1024 * 1024)
|
||||
log.Printf("metal cache limit: %dGB", cacheLimitGB)
|
||||
}
|
||||
if memLimitGB > 0 {
|
||||
mlx.SetMemoryLimit(uint64(memLimitGB) * 1024 * 1024 * 1024)
|
||||
log.Printf("metal memory limit: %dGB", memLimitGB)
|
||||
}
|
||||
|
||||
// Load model via go-ml Backend (wraps go-inference with memory management).
|
||||
log.Printf("loading model: %s", modelCfg.Paths.Base)
|
||||
backend, err := ml.NewMLXBackend(modelCfg.Paths.Base)
|
||||
if err != nil {
|
||||
log.Fatalf("load model: %v", err)
|
||||
}
|
||||
defer backend.Close()
|
||||
|
||||
log.Printf("model loaded via %s backend", backend.Name())
|
||||
```
|
||||
|
||||
Note: `backend.Close()` replaces `model.Close()`. We lose `model.Info()` for the architecture log line — that's fine, `NewMLXBackend` already logs arch/layers/quant via slog.
|
||||
|
||||
**Step 3: Build GenOpts from merged config**
|
||||
|
||||
Add this after the model loading block, before the tokeniser init (before the `tok := reversal.NewTokeniser()` line):
|
||||
|
||||
```go
|
||||
// Build generation options from merged config.
|
||||
genOpts := ml.GenOpts{
|
||||
MaxTokens: genCfg.MaxTokens,
|
||||
Temperature: genCfg.Temperature,
|
||||
TopP: genCfg.TopP,
|
||||
TopK: genCfg.TopK,
|
||||
RepeatPenalty: genCfg.RepeatPenalty,
|
||||
}
|
||||
```
|
||||
|
||||
**Step 4: Replace the inference loop**
|
||||
|
||||
Replace the inner inference block (lines 178-201):
|
||||
|
||||
Old code (lines 178-201):
|
||||
```go
|
||||
// Inference uses bare probe — the model generates from its weights.
|
||||
// Sandwich wrapping is only for the training output format.
|
||||
messages := []inference.Message{
|
||||
{Role: "user", Content: probe.Prompt},
|
||||
}
|
||||
|
||||
// Generate via native Metal inference.
|
||||
start := time.Now()
|
||||
var sb strings.Builder
|
||||
for token := range model.Chat(ctx, messages,
|
||||
inference.WithMaxTokens(genCfg.MaxTokens),
|
||||
inference.WithTemperature(float32(genCfg.Temperature)),
|
||||
inference.WithTopP(float32(genCfg.TopP)),
|
||||
inference.WithTopK(genCfg.TopK),
|
||||
inference.WithRepeatPenalty(float32(genCfg.RepeatPenalty)),
|
||||
) {
|
||||
sb.WriteString(token.Text)
|
||||
}
|
||||
if err := model.Err(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, " → ERROR: %v\n", err)
|
||||
continue
|
||||
}
|
||||
response := sb.String()
|
||||
elapsed := time.Since(start)
|
||||
```
|
||||
|
||||
New code:
|
||||
```go
|
||||
// Inference uses bare probe — the model generates from its weights.
|
||||
// Sandwich wrapping is only for the training output format.
|
||||
messages := []ml.Message{
|
||||
{Role: "user", Content: probe.Prompt},
|
||||
}
|
||||
|
||||
// Generate via go-ml Backend (memory-managed Metal inference).
|
||||
start := time.Now()
|
||||
result, err := backend.Chat(ctx, messages, genOpts)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, " → ERROR: %v\n", err)
|
||||
continue
|
||||
}
|
||||
response := result.Text
|
||||
elapsed := time.Since(start)
|
||||
```
|
||||
|
||||
**Step 5: Replace metrics access**
|
||||
|
||||
Replace the metrics line (line 214):
|
||||
|
||||
Old:
|
||||
```go
|
||||
met := model.Metrics()
|
||||
fmt.Fprintf(os.Stderr, " → %d chars, g=%.1f up=%+.1f echo=%.2f enr=%+.1f, %.1fs (%.0f tok/s)\n",
|
||||
len(response), grammar.Composite,
|
||||
delta.Uplift, delta.Echo, delta.Enrichment,
|
||||
elapsed.Seconds(), met.DecodeTokensPerSec)
|
||||
```
|
||||
|
||||
New:
|
||||
```go
|
||||
tokPerSec := 0.0
|
||||
if result.Metrics != nil {
|
||||
tokPerSec = result.Metrics.DecodeTokensPerSec
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, " → %d chars, g=%.1f up=%+.1f echo=%.2f enr=%+.1f, %.1fs (%.0f tok/s)\n",
|
||||
len(response), grammar.Composite,
|
||||
delta.Uplift, delta.Echo, delta.Enrichment,
|
||||
elapsed.Seconds(), tokPerSec)
|
||||
```
|
||||
|
||||
**Step 6: Add runtime.GC() after each probe**
|
||||
|
||||
After the quality gate block's closing brace (after line 257 — the closing `}` of the `if best != nil` / `else` block), add:
|
||||
|
||||
```go
|
||||
|
||||
// Release GPU memory between probes to prevent incremental leak.
|
||||
runtime.GC()
|
||||
```
|
||||
|
||||
**Step 7: Update the summary footer**
|
||||
|
||||
Replace the model info line in the summary (line 263):
|
||||
|
||||
Old:
|
||||
```go
|
||||
fmt.Fprintf(os.Stderr, "Model: %s (%s)\n", modelCfg.Name, info.Architecture)
|
||||
```
|
||||
|
||||
New:
|
||||
```go
|
||||
fmt.Fprintf(os.Stderr, "Model: %s (%s)\n", modelCfg.Name, backend.Name())
|
||||
```
|
||||
|
||||
**Step 8: Verify build**
|
||||
|
||||
Run: `cd /Users/snider/Code/LEM && go build ./...`
|
||||
Expected: Clean build. No remaining references to `go-inference` in distill.go.
|
||||
|
||||
**Step 9: Verify no stale inference imports**
|
||||
|
||||
Run: `grep -n 'go-inference' /Users/snider/Code/LEM/pkg/lem/distill.go`
|
||||
Expected: No output (import fully removed)
|
||||
|
||||
**Step 10: Commit**
|
||||
|
||||
```bash
|
||||
cd /Users/snider/Code/LEM
|
||||
git add pkg/lem/distill.go
|
||||
git commit -m "$(cat <<'EOF'
|
||||
feat(distill): migrate from go-inference to go-ml Backend
|
||||
|
||||
Replace inference.LoadModel() with ml.NewMLXBackend() which wraps
|
||||
the same Metal model with memory management (SetCacheLimit,
|
||||
SetMemoryLimit). Replace raw iter.Seq token loop with backend.Chat()
|
||||
returning Result{Text, Metrics}. Add runtime.GC() between probes
|
||||
to prevent incremental memory leak.
|
||||
|
||||
Reference: go-ml/cmd/cmd_ab.go memory management pattern.
|
||||
|
||||
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
|
||||
EOF
|
||||
)"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 5: Update backend_metal.go
|
||||
|
||||
`backend_metal.go` currently blank-imports `go-mlx` to register the Metal backend. Since `ml.NewMLXBackend()` (which we now call from distill.go) already does this import via `go-ml/backend_mlx.go`, the LEM-side blank import may be redundant. However, keep it for safety — it ensures the Metal backend is registered even if distill.go isn't the only consumer.
|
||||
|
||||
**Files:**
|
||||
- Modify: `pkg/lem/backend_metal.go`
|
||||
|
||||
**Step 1: Verify the file is still needed**
|
||||
|
||||
Read `pkg/lem/backend_metal.go`. It should contain:
|
||||
```go
|
||||
//go:build darwin && arm64
|
||||
|
||||
package lem
|
||||
|
||||
import _ "forge.lthn.ai/core/go-mlx"
|
||||
```
|
||||
|
||||
This is still valid. `go-mlx` registers itself via `init()`, and `ml.NewMLXBackend()` also imports it. The double import is harmless (Go deduplicates). No change needed here — leave as-is.
|
||||
|
||||
**Step 2: Verify build on darwin/arm64**
|
||||
|
||||
Run: `cd /Users/snider/Code/LEM && go build ./...`
|
||||
Expected: Clean build
|
||||
|
||||
No commit needed — no changes.
|
||||
|
||||
---
|
||||
|
||||
### Task 6: Run go mod tidy and verify
|
||||
|
||||
After all code changes, clean up the dependency graph.
|
||||
|
||||
**Files:**
|
||||
- Modify: `go.mod`, `go.sum`
|
||||
|
||||
**Step 1: Run go mod tidy**
|
||||
|
||||
Run: `cd /Users/snider/Code/LEM && go mod tidy`
|
||||
|
||||
This may remove `go-inference` from the direct require block if distill.go was the only direct consumer. Check: `backend_metal.go` imports `go-mlx` (not go-inference), and no other `.go` files in `pkg/lem/` import go-inference directly.
|
||||
|
||||
**Step 2: Check if go-inference moved to indirect**
|
||||
|
||||
Run: `grep 'go-inference' /Users/snider/Code/LEM/go.mod`
|
||||
|
||||
Expected: Either removed entirely (if go-ml pulls it transitively) or moved to `// indirect`. Either is correct.
|
||||
|
||||
**Step 3: Full build**
|
||||
|
||||
Run: `cd /Users/snider/Code/LEM && go build ./...`
|
||||
Expected: Clean build
|
||||
|
||||
**Step 4: Run go vet**
|
||||
|
||||
Run: `cd /Users/snider/Code/LEM && go vet ./...`
|
||||
Expected: Clean (no issues)
|
||||
|
||||
**Step 5: Commit if go.mod/go.sum changed**
|
||||
|
||||
```bash
|
||||
cd /Users/snider/Code/LEM
|
||||
git add go.mod go.sum
|
||||
git commit -m "$(cat <<'EOF'
|
||||
chore: go mod tidy after distill migration
|
||||
|
||||
go-inference moves to indirect (pulled transitively via go-ml).
|
||||
go-ml is now a direct dependency.
|
||||
|
||||
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
|
||||
EOF
|
||||
)"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 7: Smoke test with --dry-run
|
||||
|
||||
Verify the full flag pipeline works end-to-end without loading a model.
|
||||
|
||||
**Files:** None (test only)
|
||||
|
||||
**Step 1: Build the lem binary**
|
||||
|
||||
Run: `cd /Users/snider/Code/LEM && go build -o lem .`
|
||||
Expected: Binary built successfully
|
||||
|
||||
**Step 2: Run dry-run**
|
||||
|
||||
Run: `cd /Users/snider/Code/LEM && ./lem distill --model gemma3/1b --probes core --dry-run`
|
||||
|
||||
Expected output (approximate):
|
||||
```
|
||||
Model: gemma-3-1b-it (path...)
|
||||
Backend: metal
|
||||
Probes: 101
|
||||
Runs: 3 per probe (303 total generations)
|
||||
Gate: grammar v3 composite >= 40.0
|
||||
Generate: temp=0.80 max_tokens=4096 top_p=0.95
|
||||
Memory: cache=8GB limit=16GB
|
||||
Output: (path to lesson file)
|
||||
|
||||
core-001: ...
|
||||
core-002: ...
|
||||
... and 91 more
|
||||
```
|
||||
|
||||
Key checks:
|
||||
- `Memory:` line appears with values from ai.yaml (8/16)
|
||||
- No crash, no import errors
|
||||
|
||||
**Step 3: Test flag override**
|
||||
|
||||
Run: `cd /Users/snider/Code/LEM && ./lem distill --model gemma3/1b --probes core --dry-run --cache-limit 4 --mem-limit 8`
|
||||
|
||||
Expected: `Memory: cache=4GB limit=8GB` (flag overrides config)
|
||||
|
||||
No commit needed — test only.
|
||||
|
||||
---
|
||||
|
||||
### Task 8: Live inference test (optional, requires GPU)
|
||||
|
||||
Only run this if on a machine with the model downloaded and Metal GPU available.
|
||||
|
||||
**Files:** None (test only)
|
||||
|
||||
**Step 1: Run a single probe with memory limits**
|
||||
|
||||
Run:
|
||||
```bash
|
||||
cd /Users/snider/Code/LEM
|
||||
./lem distill --model gemma3/1b --probes core --runs 1 --cache-limit 8 --mem-limit 16 2>&1 | head -30
|
||||
```
|
||||
|
||||
Expected:
|
||||
- Model loads with memory limit logs
|
||||
- First probe generates, shows tok/s
|
||||
- No memory pressure red zone
|
||||
- `runtime.GC()` runs between probes (no visible output, but memory stays bounded)
|
||||
|
||||
**Step 2: Monitor memory**
|
||||
|
||||
In a separate terminal: `watch -n1 'vm_stat | head -5'` (`sysctl hw.memsize` only reports total RAM and never changes; `vm_stat` is what shows movement).
|
||||
|
||||
Or check Activity Monitor → Memory Pressure. Should stay green/yellow, not red.
|
||||
|
||||
No commit needed — test only.
|
||||
|
||||
---
|
||||
|
||||
## Summary of Changes
|
||||
|
||||
| File | Change |
|
||||
|------|--------|
|
||||
| `go.mod` | Add `go-ml` to require, `go-inference` moves to indirect |
|
||||
| `go.sum` | Updated transitively |
|
||||
| `pkg/lem/config.go:39-42` | Add `CacheLimit`, `MemoryLimit` to `DistillConfig` |
|
||||
| `.core/ai/ai.yaml:27-29` | Add `cache_limit: 8`, `memory_limit: 16` |
|
||||
| `pkg/lem/distill.go` | Full migration: imports, model loading, inference loop, metrics, GC |
|
||||
| `pkg/lem/backend_metal.go` | No change (blank import still valid) |
|
||||
|
||||
## What Stays the Same
|
||||
|
||||
- Grammar v3 scoring (`go-i18n/reversal`) — unchanged
|
||||
- Sandwich output format — unchanged
|
||||
- Bare probe inference (model sees probe only) — unchanged
|
||||
- Best-of-N selection — unchanged
|
||||
- Quality gate — unchanged
|
||||
- All probe loading, config merging, output writing — unchanged
|
||||
- `main.go` routing — unchanged
|
||||
|
|
@ -44,11 +44,11 @@ type checkpoint struct {
|
|||
|
||||
// probeResult holds the result of running all probes against a checkpoint.
|
||||
type probeResult struct {
|
||||
Accuracy float64 `json:"accuracy"`
|
||||
Correct int `json:"correct"`
|
||||
Total int `json:"total"`
|
||||
ByCategory map[string]categoryResult `json:"by_category"`
|
||||
Probes map[string]singleProbeResult `json:"probes"`
|
||||
Accuracy float64 `json:"accuracy"`
|
||||
Correct int `json:"correct"`
|
||||
Total int `json:"total"`
|
||||
ByCategory map[string]categoryResult `json:"by_category"`
|
||||
Probes map[string]singleProbeResult `json:"probes"`
|
||||
}
|
||||
|
||||
type categoryResult struct {
|
||||
|
|
@ -176,7 +176,7 @@ func discoverCheckpoints(cfg *agentConfig) ([]checkpoint, error) {
|
|||
var checkpoints []checkpoint
|
||||
iterRe := regexp.MustCompile(`(\d+)`)
|
||||
|
||||
for _, dirpath := range strings.Split(strings.TrimSpace(out), "\n") {
|
||||
for dirpath := range strings.SplitSeq(strings.TrimSpace(out), "\n") {
|
||||
if dirpath == "" {
|
||||
continue
|
||||
}
|
||||
|
|
@ -188,7 +188,7 @@ func discoverCheckpoints(cfg *agentConfig) ([]checkpoint, error) {
|
|||
continue
|
||||
}
|
||||
|
||||
for _, filepath := range strings.Split(strings.TrimSpace(filesOut), "\n") {
|
||||
for filepath := range strings.SplitSeq(strings.TrimSpace(filesOut), "\n") {
|
||||
if filepath == "" {
|
||||
continue
|
||||
}
|
||||
|
|
@ -507,7 +507,7 @@ func replayInfluxBuffer(workDir string, influx *InfluxClient) {
|
|||
}
|
||||
|
||||
var remaining []string
|
||||
for _, line := range strings.Split(strings.TrimSpace(string(data)), "\n") {
|
||||
for line := range strings.SplitSeq(strings.TrimSpace(string(data)), "\n") {
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ func TestConvertToConversations(t *testing.T) {
|
|||
{Prompt: "What is ethics?", Response: strings.Repeat("a", 100)},
|
||||
{Prompt: "Short", Response: "tiny"}, // Too short.
|
||||
{Prompt: "Error", Response: "ERROR: something"}, // Error prefix.
|
||||
{Prompt: "Empty", Response: ""}, // Empty.
|
||||
{Prompt: "Empty", Response: ""}, // Empty.
|
||||
{Prompt: "Good one", Response: strings.Repeat("b", 200)},
|
||||
}
|
||||
|
||||
|
|
@ -204,12 +204,12 @@ func TestOutputFormatCompatibility(t *testing.T) {
|
|||
}
|
||||
|
||||
// Parse back as generic map to check structure.
|
||||
var m map[string]interface{}
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal(data, &m); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
messages, ok := m["messages"].([]interface{})
|
||||
messages, ok := m["messages"].([]any)
|
||||
if !ok {
|
||||
t.Fatal("expected messages array")
|
||||
}
|
||||
|
|
@ -217,7 +217,7 @@ func TestOutputFormatCompatibility(t *testing.T) {
|
|||
t.Fatalf("expected 2 messages, got %d", len(messages))
|
||||
}
|
||||
|
||||
msg0 := messages[0].(map[string]interface{})
|
||||
msg0 := messages[0].(map[string]any)
|
||||
if msg0["role"] != "user" || msg0["content"] != "prompt" {
|
||||
t.Errorf("unexpected first message: %v", msg0)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -125,8 +125,8 @@ func transposeFloat32(data []byte, rows, cols int) []byte {
|
|||
}
|
||||
|
||||
result := make([]byte, len(data))
|
||||
for r := 0; r < rows; r++ {
|
||||
for c := 0; c < cols; c++ {
|
||||
for r := range rows {
|
||||
for c := range cols {
|
||||
srcOff := (r*cols + c) * 4
|
||||
dstOff := (c*rows + r) * 4
|
||||
copy(result[dstOff:dstOff+4], data[srcOff:srcOff+4])
|
||||
|
|
@ -142,8 +142,8 @@ func transposeFloat16(data []byte, rows, cols int) []byte {
|
|||
}
|
||||
|
||||
result := make([]byte, len(data))
|
||||
for r := 0; r < rows; r++ {
|
||||
for c := 0; c < cols; c++ {
|
||||
for r := range rows {
|
||||
for c := range cols {
|
||||
srcOff := (r*cols + c) * 2
|
||||
dstOff := (c*rows + r) * 2
|
||||
copy(result[dstOff:dstOff+2], data[srcOff:srcOff+2])
|
||||
|
|
@ -178,7 +178,7 @@ func writeSafetensors(path string, tensors map[string]safetensorsTensorInfo, ten
|
|||
}
|
||||
|
||||
// Build header JSON.
|
||||
headerMap := make(map[string]interface{})
|
||||
headerMap := make(map[string]any)
|
||||
for k, info := range updatedTensors {
|
||||
headerMap[k] = info
|
||||
}
|
||||
|
|
@ -314,23 +314,23 @@ func convertMLXtoPEFT(safetensorsPath, configPath, outputDir, baseModelName stri
|
|||
sort.Ints(sortedLayers)
|
||||
|
||||
// Write PEFT adapter_config.json.
|
||||
peftConfig := map[string]interface{}{
|
||||
"auto_mapping": nil,
|
||||
peftConfig := map[string]any{
|
||||
"auto_mapping": nil,
|
||||
"base_model_name_or_path": baseModelName,
|
||||
"bias": "none",
|
||||
"fan_in_fan_out": false,
|
||||
"inference_mode": true,
|
||||
"init_lora_weights": true,
|
||||
"layers_pattern": nil,
|
||||
"layers_to_transform": sortedLayers,
|
||||
"lora_alpha": math.Round(scale * float64(rank)),
|
||||
"lora_dropout": mlxConfig.LoraParameters.Dropout,
|
||||
"modules_to_save": nil,
|
||||
"peft_type": "LORA",
|
||||
"r": rank,
|
||||
"revision": nil,
|
||||
"target_modules": sortedModules,
|
||||
"task_type": "CAUSAL_LM",
|
||||
"bias": "none",
|
||||
"fan_in_fan_out": false,
|
||||
"inference_mode": true,
|
||||
"init_lora_weights": true,
|
||||
"layers_pattern": nil,
|
||||
"layers_to_transform": sortedLayers,
|
||||
"lora_alpha": math.Round(scale * float64(rank)),
|
||||
"lora_dropout": mlxConfig.LoraParameters.Dropout,
|
||||
"modules_to_save": nil,
|
||||
"peft_type": "LORA",
|
||||
"r": rank,
|
||||
"revision": nil,
|
||||
"target_modules": sortedModules,
|
||||
"task_type": "CAUSAL_LM",
|
||||
}
|
||||
|
||||
cfgJSON, err := json.MarshalIndent(peftConfig, "", " ")
|
||||
|
|
|
|||
|
|
@ -67,11 +67,11 @@ func TestConvertMLXtoPEFT(t *testing.T) {
|
|||
|
||||
// Create tensor data: 4x2=8 floats and 2x4=8 floats.
|
||||
loraAData := make([]byte, 4*2*4)
|
||||
for i := 0; i < 8; i++ {
|
||||
for i := range 8 {
|
||||
binary.LittleEndian.PutUint32(loraAData[i*4:], math.Float32bits(float32(i+1)))
|
||||
}
|
||||
loraBData := make([]byte, 2*4*4)
|
||||
for i := 0; i < 8; i++ {
|
||||
for i := range 8 {
|
||||
binary.LittleEndian.PutUint32(loraBData[i*4:], math.Float32bits(float32(10+i)))
|
||||
}
|
||||
|
||||
|
|
@ -85,8 +85,8 @@ func TestConvertMLXtoPEFT(t *testing.T) {
|
|||
}
|
||||
|
||||
// Create MLX config.
|
||||
mlxConfig := map[string]interface{}{
|
||||
"lora_parameters": map[string]interface{}{
|
||||
mlxConfig := map[string]any{
|
||||
"lora_parameters": map[string]any{
|
||||
"rank": 8,
|
||||
"scale": 20.0,
|
||||
"dropout": 0.0,
|
||||
|
|
@ -116,7 +116,7 @@ func TestConvertMLXtoPEFT(t *testing.T) {
|
|||
t.Fatalf("read peft config: %v", err)
|
||||
}
|
||||
|
||||
var peftConfig map[string]interface{}
|
||||
var peftConfig map[string]any
|
||||
if err := json.Unmarshal(peftCfgData, &peftConfig); err != nil {
|
||||
t.Fatalf("parse peft config: %v", err)
|
||||
}
|
||||
|
|
@ -166,7 +166,7 @@ func TestReadWriteSafetensorsRoundtrip(t *testing.T) {
|
|||
data := map[string][]byte{
|
||||
"weight_a": make([]byte, 2*3*4),
|
||||
}
|
||||
for i := 0; i < 6; i++ {
|
||||
for i := range 6 {
|
||||
binary.LittleEndian.PutUint32(data["weight_a"][i*4:], math.Float32bits(float32(i)))
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -83,10 +83,7 @@ func RunCoverage(args []string) {
|
|||
|
||||
avg := float64(total) / float64(len(regionRows))
|
||||
for _, r := range regionRows {
|
||||
barLen := int(float64(r.n) / avg * 10)
|
||||
if barLen > 40 {
|
||||
barLen = 40
|
||||
}
|
||||
barLen := min(int(float64(r.n)/avg*10), 40)
|
||||
bar := strings.Repeat("#", barLen)
|
||||
gap := ""
|
||||
if float64(r.n) < avg*0.5 {
|
||||
|
|
|
|||
|
|
@ -110,7 +110,7 @@ func (db *DB) CountGoldenSet() (int, error) {
|
|||
func (db *DB) QueryExpansionPrompts(status string, limit int) ([]ExpansionPromptRow, error) {
|
||||
query := "SELECT idx, seed_id, region, domain, language, prompt, prompt_en, priority, status " +
|
||||
"FROM expansion_prompts"
|
||||
var args []interface{}
|
||||
var args []any
|
||||
|
||||
if status != "" {
|
||||
query += " WHERE status = ?"
|
||||
|
|
@ -163,7 +163,7 @@ func (db *DB) UpdateExpansionStatus(idx int64, status string) error {
|
|||
}
|
||||
|
||||
// QueryRows executes an arbitrary SQL query and returns results as maps.
|
||||
func (db *DB) QueryRows(query string, args ...interface{}) ([]map[string]interface{}, error) {
|
||||
func (db *DB) QueryRows(query string, args ...any) ([]map[string]any, error) {
|
||||
rows, err := db.conn.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query: %w", err)
|
||||
|
|
@ -175,17 +175,17 @@ func (db *DB) QueryRows(query string, args ...interface{}) ([]map[string]interfa
|
|||
return nil, fmt.Errorf("columns: %w", err)
|
||||
}
|
||||
|
||||
var result []map[string]interface{}
|
||||
var result []map[string]any
|
||||
for rows.Next() {
|
||||
values := make([]interface{}, len(cols))
|
||||
ptrs := make([]interface{}, len(cols))
|
||||
values := make([]any, len(cols))
|
||||
ptrs := make([]any, len(cols))
|
||||
for i := range values {
|
||||
ptrs[i] = &values[i]
|
||||
}
|
||||
if err := rows.Scan(ptrs...); err != nil {
|
||||
return nil, fmt.Errorf("scan: %w", err)
|
||||
}
|
||||
row := make(map[string]interface{}, len(cols))
|
||||
row := make(map[string]any, len(cols))
|
||||
for i, col := range cols {
|
||||
row[col] = values[i]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ func NewEngine(judge *Judge, concurrency int, suiteList string) *Engine {
|
|||
suites["standard"] = true
|
||||
suites["exact"] = true
|
||||
} else {
|
||||
for _, s := range strings.Split(suiteList, ",") {
|
||||
for s := range strings.SplitSeq(suiteList, ",") {
|
||||
s = strings.TrimSpace(s)
|
||||
if s != "" {
|
||||
suites[s] = true
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ import (
|
|||
// apiResponse is returned for /v1/chat/completions requests.
|
||||
// influxQueryHandler handles /api/v3/query_sql requests.
|
||||
// influxWriteHandler handles /api/v3/write_lp requests.
|
||||
func mockExpandServer(t *testing.T, apiResponse string, influxQueryHandler func(q string) ([]map[string]interface{}, int), influxWriteHandler func(body string)) *httptest.Server {
|
||||
func mockExpandServer(t *testing.T, apiResponse string, influxQueryHandler func(q string) ([]map[string]any, int), influxWriteHandler func(body string)) *httptest.Server {
|
||||
t.Helper()
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
|
|
@ -61,11 +61,11 @@ func mockExpandServer(t *testing.T, apiResponse string, influxQueryHandler func(
|
|||
|
||||
func TestGetCompletedIDs(t *testing.T) {
|
||||
t.Run("returns completed IDs from InfluxDB", func(t *testing.T) {
|
||||
server := mockInfluxServer(t, func(q string) ([]map[string]interface{}, int) {
|
||||
server := mockInfluxServer(t, func(q string) ([]map[string]any, int) {
|
||||
if !strings.Contains(q, "expansion_gen") {
|
||||
t.Errorf("expected query against expansion_gen, got: %s", q)
|
||||
}
|
||||
return []map[string]interface{}{
|
||||
return []map[string]any{
|
||||
{"seed_id": "prompt_001"},
|
||||
{"seed_id": "prompt_002"},
|
||||
{"seed_id": "prompt_003"},
|
||||
|
|
@ -92,8 +92,8 @@ func TestGetCompletedIDs(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("returns empty set when no completed IDs", func(t *testing.T) {
|
||||
server := mockInfluxServer(t, func(q string) ([]map[string]interface{}, int) {
|
||||
return []map[string]interface{}{}, http.StatusOK
|
||||
server := mockInfluxServer(t, func(q string) ([]map[string]any, int) {
|
||||
return []map[string]any{}, http.StatusOK
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
|
|
@ -133,9 +133,9 @@ func TestExpandPromptsBasic(t *testing.T) {
|
|||
|
||||
server := mockExpandServer(t,
|
||||
"This is a generated response about ethics and sovereignty.",
|
||||
func(q string) ([]map[string]interface{}, int) {
|
||||
func(q string) ([]map[string]any, int) {
|
||||
// No completed IDs
|
||||
return []map[string]interface{}{}, http.StatusOK
|
||||
return []map[string]any{}, http.StatusOK
|
||||
},
|
||||
func(body string) {
|
||||
writtenLines = append(writtenLines, body)
|
||||
|
|
@ -229,8 +229,8 @@ func TestExpandPromptsSkipsCompleted(t *testing.T) {
|
|||
defer apiServer.Close()
|
||||
|
||||
// InfluxDB returns p1 and p2 as already completed.
|
||||
influxServer := mockInfluxServer(t, func(q string) ([]map[string]interface{}, int) {
|
||||
return []map[string]interface{}{
|
||||
influxServer := mockInfluxServer(t, func(q string) ([]map[string]any, int) {
|
||||
return []map[string]any{
|
||||
{"seed_id": "p1"},
|
||||
{"seed_id": "p2"},
|
||||
}, http.StatusOK
|
||||
|
|
@ -292,8 +292,8 @@ func TestExpandPromptsAllCompleted(t *testing.T) {
|
|||
defer apiServer.Close()
|
||||
|
||||
// All prompts already completed.
|
||||
influxServer := mockInfluxServer(t, func(q string) ([]map[string]interface{}, int) {
|
||||
return []map[string]interface{}{
|
||||
influxServer := mockInfluxServer(t, func(q string) ([]map[string]any, int) {
|
||||
return []map[string]any{
|
||||
{"seed_id": "p1"},
|
||||
{"seed_id": "p2"},
|
||||
}, http.StatusOK
|
||||
|
|
@ -338,8 +338,8 @@ func TestExpandPromptsDryRun(t *testing.T) {
|
|||
}))
|
||||
defer apiServer.Close()
|
||||
|
||||
influxServer := mockInfluxServer(t, func(q string) ([]map[string]interface{}, int) {
|
||||
return []map[string]interface{}{
|
||||
influxServer := mockInfluxServer(t, func(q string) ([]map[string]any, int) {
|
||||
return []map[string]any{
|
||||
{"seed_id": "p1"},
|
||||
}, http.StatusOK
|
||||
})
|
||||
|
|
@ -398,8 +398,8 @@ func TestExpandPromptsAPIErrorSkipsPrompt(t *testing.T) {
|
|||
}))
|
||||
defer apiServer.Close()
|
||||
|
||||
influxServer := mockInfluxServer(t, func(q string) ([]map[string]interface{}, int) {
|
||||
return []map[string]interface{}{}, http.StatusOK
|
||||
influxServer := mockInfluxServer(t, func(q string) ([]map[string]any, int) {
|
||||
return []map[string]any{}, http.StatusOK
|
||||
})
|
||||
defer influxServer.Close()
|
||||
|
||||
|
|
@ -462,7 +462,7 @@ func TestExpandPromptsInfluxWriteErrorNonFatal(t *testing.T) {
|
|||
switch r.URL.Path {
|
||||
case "/api/v3/query_sql":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode([]map[string]interface{}{})
|
||||
json.NewEncoder(w).Encode([]map[string]any{})
|
||||
case "/api/v3/write_lp":
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte("write failed"))
|
||||
|
|
@ -516,8 +516,8 @@ func TestExpandPromptsOutputJSONLStructure(t *testing.T) {
|
|||
}))
|
||||
defer apiServer.Close()
|
||||
|
||||
influxServer := mockInfluxServer(t, func(q string) ([]map[string]interface{}, int) {
|
||||
return []map[string]interface{}{}, http.StatusOK
|
||||
influxServer := mockInfluxServer(t, func(q string) ([]map[string]any, int) {
|
||||
return []map[string]any{}, http.StatusOK
|
||||
})
|
||||
defer influxServer.Close()
|
||||
|
||||
|
|
@ -589,7 +589,7 @@ func TestExpandPromptsInfluxLineProtocol(t *testing.T) {
|
|||
switch r.URL.Path {
|
||||
case "/api/v3/query_sql":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode([]map[string]interface{}{})
|
||||
json.NewEncoder(w).Encode([]map[string]any{})
|
||||
case "/api/v3/write_lp":
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
writtenBodies = append(writtenBodies, string(body))
|
||||
|
|
@ -658,8 +658,8 @@ func TestExpandPromptsAppendMode(t *testing.T) {
|
|||
}))
|
||||
defer apiServer.Close()
|
||||
|
||||
influxServer := mockInfluxServer(t, func(q string) ([]map[string]interface{}, int) {
|
||||
return []map[string]interface{}{}, http.StatusOK
|
||||
influxServer := mockInfluxServer(t, func(q string) ([]map[string]any, int) {
|
||||
return []map[string]any{}, http.StatusOK
|
||||
})
|
||||
defer influxServer.Close()
|
||||
|
||||
|
|
@ -735,8 +735,8 @@ func TestExpandPromptsLimit(t *testing.T) {
|
|||
}))
|
||||
defer apiServer.Close()
|
||||
|
||||
influxServer := mockInfluxServer(t, func(q string) ([]map[string]interface{}, int) {
|
||||
return []map[string]interface{}{}, http.StatusOK
|
||||
influxServer := mockInfluxServer(t, func(q string) ([]map[string]any, int) {
|
||||
return []map[string]any{}, http.StatusOK
|
||||
})
|
||||
defer influxServer.Close()
|
||||
|
||||
|
|
@ -813,8 +813,8 @@ func TestExpandPromptsLimitAfterFiltering(t *testing.T) {
|
|||
defer apiServer.Close()
|
||||
|
||||
// p1 and p2 are already completed.
|
||||
influxServer := mockInfluxServer(t, func(q string) ([]map[string]interface{}, int) {
|
||||
return []map[string]interface{}{
|
||||
influxServer := mockInfluxServer(t, func(q string) ([]map[string]any, int) {
|
||||
return []map[string]any{
|
||||
{"seed_id": "p1"},
|
||||
{"seed_id": "p2"},
|
||||
}, http.StatusOK
|
||||
|
|
@ -855,8 +855,8 @@ func TestExpandPromptsLimitAfterFiltering(t *testing.T) {
|
|||
t.Fatalf("read output file: %v", err)
|
||||
}
|
||||
|
||||
lines := strings.Split(strings.TrimSpace(string(data)), "\n")
|
||||
for _, line := range lines {
|
||||
lines := strings.SplitSeq(strings.TrimSpace(string(data)), "\n")
|
||||
for line := range lines {
|
||||
var r Response
|
||||
if err := json.Unmarshal([]byte(line), &r); err != nil {
|
||||
t.Fatalf("parse output line: %v", err)
|
||||
|
|
@ -893,8 +893,8 @@ func TestExpandPromptsLimitZeroMeansAll(t *testing.T) {
|
|||
}))
|
||||
defer apiServer.Close()
|
||||
|
||||
influxServer := mockInfluxServer(t, func(q string) ([]map[string]interface{}, int) {
|
||||
return []map[string]interface{}{}, http.StatusOK
|
||||
influxServer := mockInfluxServer(t, func(q string) ([]map[string]any, int) {
|
||||
return []map[string]any{}, http.StatusOK
|
||||
})
|
||||
defer influxServer.Close()
|
||||
|
||||
|
|
@ -943,8 +943,8 @@ func TestExpandPromptsOutputHasCharsField(t *testing.T) {
|
|||
}))
|
||||
defer apiServer.Close()
|
||||
|
||||
influxServer := mockInfluxServer(t, func(q string) ([]map[string]interface{}, int) {
|
||||
return []map[string]interface{}{}, http.StatusOK
|
||||
influxServer := mockInfluxServer(t, func(q string) ([]map[string]any, int) {
|
||||
return []map[string]any{}, http.StatusOK
|
||||
})
|
||||
defer influxServer.Close()
|
||||
|
||||
|
|
@ -972,7 +972,7 @@ func TestExpandPromptsOutputHasCharsField(t *testing.T) {
|
|||
}
|
||||
|
||||
// Parse as raw JSON to check for the chars field.
|
||||
var raw map[string]interface{}
|
||||
var raw map[string]any
|
||||
if err := json.Unmarshal([]byte(strings.TrimSpace(string(data))), &raw); err != nil {
|
||||
t.Fatalf("parse raw JSON: %v", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -90,13 +90,13 @@ func TestFilterResponses(t *testing.T) {
|
|||
func TestSplitData(t *testing.T) {
|
||||
// Create 100 responses for easy percentage calculation.
|
||||
responses := make([]Response, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
for i := range 100 {
|
||||
responses[i] = Response{ID: "r" + string(rune('0'+i/10)) + string(rune('0'+i%10))}
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
trainPct, validPct, testPct int
|
||||
name string
|
||||
trainPct, validPct, testPct int
|
||||
wantTrain, wantValid, wantTest int
|
||||
}{
|
||||
{
|
||||
|
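The counter loop above becomes a range over an int (Go 1.22+). A compact sketch, with Sprintf standing in for the rune arithmetic the test uses:

```go
package main

import "fmt"

func main() {
	ids := make([]string, 100)
	// "for i := range 100" iterates i = 0..99, replacing the classic
	// three-clause counter loop.
	for i := range 100 {
		ids[i] = fmt.Sprintf("r%02d", i)
	}
	fmt.Println(ids[0], ids[99]) // r00 r99
}
```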
|
@ -411,7 +411,7 @@ func TestExportPercentageValidation(t *testing.T) {
|
|||
|
||||
// Helper functions.
|
||||
|
||||
func mustJSON(t *testing.T, v interface{}) string {
|
||||
func mustJSON(t *testing.T, v any) string {
|
||||
t.Helper()
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -317,7 +317,7 @@ func importBenchmarkFile(db *DB, path, source string) int {
|
|||
scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
|
||||
|
||||
for scanner.Scan() {
|
||||
var rec map[string]interface{}
|
||||
var rec map[string]any
|
||||
if err := json.Unmarshal(scanner.Bytes(), &rec); err != nil {
|
||||
continue
|
||||
}
|
||||
|
|
@ -349,7 +349,7 @@ func importBenchmarkQuestions(db *DB, path, benchmark string) int {
|
|||
scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
|
||||
|
||||
for scanner.Scan() {
|
||||
var rec map[string]interface{}
|
||||
var rec map[string]any
|
||||
if err := json.Unmarshal(scanner.Bytes(), &rec); err != nil {
|
||||
continue
|
||||
}
|
||||
|
|
@ -387,26 +387,26 @@ func importSeeds(db *DB, seedDir string) int {
|
|||
region := strings.TrimSuffix(filepath.Base(path), ".json")
|
||||
|
||||
// Try parsing as array or object with prompts/seeds field.
|
||||
var seedsList []interface{}
|
||||
var raw interface{}
|
||||
var seedsList []any
|
||||
var raw any
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch v := raw.(type) {
|
||||
case []interface{}:
|
||||
case []any:
|
||||
seedsList = v
|
||||
case map[string]interface{}:
|
||||
if prompts, ok := v["prompts"].([]interface{}); ok {
|
||||
case map[string]any:
|
||||
if prompts, ok := v["prompts"].([]any); ok {
|
||||
seedsList = prompts
|
||||
} else if seeds, ok := v["seeds"].([]interface{}); ok {
|
||||
} else if seeds, ok := v["seeds"].([]any); ok {
|
||||
seedsList = seeds
|
||||
}
|
||||
}
|
||||
|
||||
for _, s := range seedsList {
|
||||
switch seed := s.(type) {
|
||||
case map[string]interface{}:
|
||||
case map[string]any:
|
||||
prompt := strOrEmpty(seed, "prompt")
|
||||
if prompt == "" {
|
||||
prompt = strOrEmpty(seed, "text")
|
||||
|
|
@ -432,14 +432,14 @@ func importSeeds(db *DB, seedDir string) int {
|
|||
return count
|
||||
}
|
||||
|
||||
func strOrEmpty(m map[string]interface{}, key string) string {
|
||||
func strOrEmpty(m map[string]any, key string) string {
|
||||
if v, ok := m[key]; ok {
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func floatOrZero(m map[string]interface{}, key string) float64 {
|
||||
func floatOrZero(m map[string]any, key string) float64 {
|
||||
if v, ok := m[key]; ok {
|
||||
if f, ok := v.(float64); ok {
|
||||
return f
|
||||
|
|
|
|||
|
|
@ -80,7 +80,7 @@ func (c *InfluxClient) WriteLp(lines []string) error {
|
|||
|
||||
// QuerySQL runs a SQL query against InfluxDB and returns the result rows.
|
||||
// POST to /api/v3/query_sql with JSON body {"db": db, "q": sql}.
|
||||
func (c *InfluxClient) QuerySQL(sql string) ([]map[string]interface{}, error) {
|
||||
func (c *InfluxClient) QuerySQL(sql string) ([]map[string]any, error) {
|
||||
reqBody := map[string]string{
|
||||
"db": c.db,
|
||||
"q": sql,
|
||||
|
|
@ -116,7 +116,7 @@ func (c *InfluxClient) QuerySQL(sql string) ([]map[string]interface{}, error) {
|
|||
return nil, fmt.Errorf("query failed %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var rows []map[string]interface{}
|
||||
var rows []map[string]any
|
||||
if err := json.Unmarshal(respBody, &rows); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal query response: %w", err)
|
||||
}
|
||||
|
|
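A self-contained sketch of the wire format the comment above describes: POST {"db": ..., "q": ...} to /api/v3/query_sql and decode the reply as []map[string]any. The base URL and db name in main are placeholders, and auth/timeouts are omitted.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// querySQL is an illustrative stand-in for InfluxClient.QuerySQL.
func querySQL(baseURL, db, sqlQuery string) ([]map[string]any, error) {
	body, err := json.Marshal(map[string]string{"db": db, "q": sqlQuery})
	if err != nil {
		return nil, err
	}
	resp, err := http.Post(baseURL+"/api/v3/query_sql", "application/json", bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("query failed %d", resp.StatusCode)
	}
	var rows []map[string]any
	if err := json.NewDecoder(resp.Body).Decode(&rows); err != nil {
		return nil, fmt.Errorf("unmarshal query response: %w", err)
	}
	return rows, nil
}

func main() {
	rows, err := querySQL("http://localhost:8181", "lem", "SELECT seed_id FROM expansion_gen")
	fmt.Println(rows, err)
}
```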
|
|||
|
|
@ -158,7 +158,7 @@ func TestQuerySQL(t *testing.T) {
|
|||
var gotContentType string
|
||||
var gotPath string
|
||||
|
||||
responseData := []map[string]interface{}{
|
||||
responseData := []map[string]any{
|
||||
{"id": "row1", "score": float64(7.5)},
|
||||
{"id": "row2", "score": float64(8.2)},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -100,7 +100,7 @@ type contentScoreEntry struct {
|
|||
}
|
||||
|
||||
type contentProbeEntry struct {
|
||||
Scores map[string]interface{} `json:"scores"`
|
||||
Scores map[string]any `json:"scores"`
|
||||
}
|
||||
|
||||
// ingestContentScores reads a content scores JSONL file and writes
|
||||
|
|
@ -316,7 +316,7 @@ func ingestTrainingCurve(influx *InfluxClient, filepath, model, runID string, ba
|
|||
}
|
||||
|
||||
// toFloat64 converts an interface{} to float64 if possible.
|
||||
func toFloat64(v interface{}) (float64, bool) {
|
||||
func toFloat64(v any) (float64, bool) {
|
||||
switch n := v.(type) {
|
||||
case float64:
|
||||
return n, true
|
||||
|
|
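Only the float64 arm of toFloat64 is visible in this hunk. A hedged sketch of the general shape; the integer and float32 arms below are illustrative, not necessarily the repo's exact cases:

```go
package main

import "fmt"

// toFloat64 converts an any value to float64 when it can.
func toFloat64(v any) (float64, bool) {
	switch n := v.(type) {
	case float64:
		return n, true
	case float32:
		return float64(n), true
	case int:
		return float64(n), true
	case int64:
		return float64(n), true
	default:
		return 0, false
	}
}

func main() {
	fmt.Println(toFloat64(7.5))       // 7.5 true
	fmt.Println(toFloat64("iter 50")) // 0 false
}
```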
|
|||
|
|
@ -50,7 +50,7 @@ func TestIngestContentScores(t *testing.T) {
|
|||
Label: "gemma12b@200",
|
||||
Aggregates: map[string]float64{"sovereignty": 7.5, "ethical_depth": 8.0},
|
||||
Probes: map[string]contentProbeEntry{
|
||||
"p01": {Scores: map[string]interface{}{"sovereignty": 8.0, "notes": "good"}},
|
||||
"p01": {Scores: map[string]any{"sovereignty": 8.0, "notes": "good"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
|
|
@ -201,7 +201,7 @@ Iter 50: Val loss 1.523
|
|||
|
||||
func TestToFloat64(t *testing.T) {
|
||||
tests := []struct {
|
||||
input interface{}
|
||||
input any
|
||||
want float64
|
||||
ok bool
|
||||
}{
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import (
|
|||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// RunInventory is the CLI entry point for the inventory command.
|
||||
|
|
@ -86,12 +87,12 @@ func RunInventory(args []string) {
|
|||
}
|
||||
|
||||
func joinStrings(parts []string, sep string) string {
|
||||
result := ""
|
||||
var result strings.Builder
|
||||
for i, p := range parts {
|
||||
if i > 0 {
|
||||
result += sep
|
||||
result.WriteString(sep)
|
||||
}
|
||||
result += p
|
||||
result.WriteString(p)
|
||||
}
|
||||
return result
|
||||
return result.String()
|
||||
}
|
||||
|
|
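The builder rewrite above grows a single buffer instead of reallocating on every +=. A sketch for context; note the helper ends up behaving exactly like strings.Join:

```go
package main

import (
	"fmt"
	"strings"
)

func joinStrings(parts []string, sep string) string {
	var b strings.Builder
	for i, p := range parts {
		if i > 0 {
			b.WriteString(sep)
		}
		b.WriteString(p)
	}
	return b.String()
}

func main() {
	parts := []string{"en", "yo", "sw"}
	fmt.Println(joinStrings(parts, ", "))
	fmt.Println(strings.Join(parts, ", ")) // identical output
}
```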
|
|||
|
|
@ -60,11 +60,11 @@ func RunQuery(args []string) {
|
|||
log.Fatalf("columns: %v", err)
|
||||
}
|
||||
|
||||
var results []map[string]interface{}
|
||||
var results []map[string]any
|
||||
|
||||
for rows.Next() {
|
||||
values := make([]interface{}, len(cols))
|
||||
ptrs := make([]interface{}, len(cols))
|
||||
values := make([]any, len(cols))
|
||||
ptrs := make([]any, len(cols))
|
||||
for i := range values {
|
||||
ptrs[i] = &values[i]
|
||||
}
|
||||
|
|
@ -73,7 +73,7 @@ func RunQuery(args []string) {
|
|||
log.Fatalf("scan: %v", err)
|
||||
}
|
||||
|
||||
row := make(map[string]interface{})
|
||||
row := make(map[string]any)
|
||||
for i, col := range cols {
|
||||
v := values[i]
|
||||
// Convert []byte to string for readability.
|
||||
|
|
|
|||
|
|
@ -170,7 +170,7 @@ func printStatus(influx *InfluxClient, w io.Writer) error {
|
|||
|
||||
// dedupeTraining merges training status and loss rows, keeping only the first
|
||||
// (latest) row per model. Returns sorted by model name.
|
||||
func dedupeTraining(statusRows, lossRows []map[string]interface{}) []trainingRow {
|
||||
func dedupeTraining(statusRows, lossRows []map[string]any) []trainingRow {
|
||||
// Build loss lookup: model -> loss value.
|
||||
lossMap := make(map[string]float64)
|
||||
lossSeenMap := make(map[string]bool)
|
||||
|
|
@ -225,7 +225,7 @@ func dedupeTraining(statusRows, lossRows []map[string]interface{}) []trainingRow
|
|||
|
||||
// dedupeGeneration deduplicates generation progress rows by worker, keeping
|
||||
// only the first (latest) row per worker. Returns sorted by worker name.
|
||||
func dedupeGeneration(rows []map[string]interface{}) []genRow {
|
||||
func dedupeGeneration(rows []map[string]any) []genRow {
|
||||
seen := make(map[string]bool)
|
||||
var result []genRow
|
||||
for _, row := range rows {
|
||||
|
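A sketch of the first-wins dedupe the comments above describe, generalised over the key column. It assumes the query already returns rows newest-first, which is what makes "first seen wins" equal "latest wins"; the sample rows are illustrative.

```go
package main

import (
	"fmt"
	"sort"
)

// dedupeByKey keeps the first row seen per key (assumed newest-first
// input) and returns the survivors sorted by key for stable output.
func dedupeByKey(rows []map[string]any, key string) []map[string]any {
	seen := make(map[string]bool)
	var out []map[string]any
	for _, row := range rows {
		k, _ := row[key].(string)
		if k == "" || seen[k] {
			continue
		}
		seen[k] = true
		out = append(out, row)
	}
	sort.Slice(out, func(i, j int) bool {
		a, _ := out[i][key].(string)
		b, _ := out[j][key].(string)
		return a < b
	})
	return out
}

func main() {
	rows := []map[string]any{
		{"model": "gemma-3-1b", "status": "training"},
		{"model": "gemma-3-1b", "status": "complete"}, // older row, dropped
		{"model": "gemma-3-12b", "status": "pending"},
	}
	fmt.Println(dedupeByKey(rows, "model"))
}
```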
|
@ -255,7 +255,7 @@ func dedupeGeneration(rows []map[string]interface{}) []genRow {
|
|||
|
||||
// strVal extracts a string value from a row map, returning "" if missing or
|
||||
// not a string.
|
||||
func strVal(row map[string]interface{}, key string) string {
|
||||
func strVal(row map[string]any, key string) string {
|
||||
v, ok := row[key]
|
||||
if !ok {
|
||||
return ""
|
||||
|
|
@ -269,7 +269,7 @@ func strVal(row map[string]interface{}, key string) string {
|
|||
|
||||
// floatVal extracts a float64 value from a row map, returning 0 if missing or
|
||||
// not a float64.
|
||||
func floatVal(row map[string]interface{}, key string) float64 {
|
||||
func floatVal(row map[string]any, key string) float64 {
|
||||
v, ok := row[key]
|
||||
if !ok {
|
||||
return 0
|
||||
|
|
@ -283,6 +283,6 @@ func floatVal(row map[string]interface{}, key string) float64 {
|
|||
|
||||
// intVal extracts an integer value from a row map. InfluxDB JSON returns all
|
||||
// numbers as float64, so this truncates to int.
|
||||
func intVal(row map[string]interface{}, key string) int {
|
||||
func intVal(row map[string]any, key string) int {
|
||||
return int(floatVal(row, key))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import (
|
|||
// mockInfluxServer creates an httptest server that routes /api/v3/query_sql
|
||||
// requests to the given handler function. The handler receives the parsed
|
||||
// query body and writes the JSON response.
|
||||
func mockInfluxServer(t *testing.T, handler func(q string) ([]map[string]interface{}, int)) *httptest.Server {
|
||||
func mockInfluxServer(t *testing.T, handler func(q string) ([]map[string]any, int)) *httptest.Server {
|
||||
t.Helper()
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/v3/query_sql" {
|
||||
|
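For readers following the tests, a compact sketch of the mock-server shape the comment describes: route /api/v3/query_sql to a per-query handler that decides the rows and status. The "q" field matches the client above; everything else here is illustrative rather than the repo's exact implementation.

```go
package mockinflux

import (
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"testing"
)

// NewMockInflux is an illustrative stand-in for mockInfluxServer.
func NewMockInflux(t *testing.T, handler func(q string) ([]map[string]any, int)) *httptest.Server {
	t.Helper()
	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/api/v3/query_sql" {
			http.NotFound(w, r)
			return
		}
		var req struct {
			Q string `json:"q"`
		}
		_ = json.NewDecoder(r.Body).Decode(&req)
		rows, status := handler(req.Q)
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(status)
		_ = json.NewEncoder(w).Encode(rows)
	}))
}
```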
|
@ -37,25 +37,25 @@ func mockInfluxServer(t *testing.T, handler func(q string) ([]map[string]interfa
|
|||
}
|
||||
|
||||
func TestPrintStatusFullOutput(t *testing.T) {
|
||||
server := mockInfluxServer(t, func(q string) ([]map[string]interface{}, int) {
|
||||
server := mockInfluxServer(t, func(q string) ([]map[string]any, int) {
|
||||
switch {
|
||||
case strings.Contains(q, "training_status"):
|
||||
return []map[string]interface{}{
|
||||
return []map[string]any{
|
||||
{"model": "gemma-3-1b", "run_id": "run1", "status": "complete", "iteration": float64(1000), "total_iters": float64(1000), "pct": float64(100.0)},
|
||||
{"model": "gemma-3-12b", "run_id": "run2", "status": "training", "iteration": float64(340), "total_iters": float64(600), "pct": float64(56.7)},
|
||||
{"model": "gemma-3-27b", "run_id": "run3", "status": "pending", "iteration": float64(0), "total_iters": float64(400), "pct": float64(0.0)},
|
||||
}, http.StatusOK
|
||||
case strings.Contains(q, "training_loss"):
|
||||
return []map[string]interface{}{
|
||||
return []map[string]any{
|
||||
{"model": "gemma-3-1b", "loss_type": "train", "loss": float64(1.434), "iteration": float64(1000), "tokens_per_sec": float64(512.3)},
|
||||
{"model": "gemma-3-12b", "loss_type": "train", "loss": float64(0.735), "iteration": float64(340), "tokens_per_sec": float64(128.5)},
|
||||
}, http.StatusOK
|
||||
case strings.Contains(q, "golden_gen_progress"):
|
||||
return []map[string]interface{}{
|
||||
return []map[string]any{
|
||||
{"worker": "m3-gpu0", "completed": float64(15000), "target": float64(15000), "pct": float64(100.0)},
|
||||
}, http.StatusOK
|
||||
case strings.Contains(q, "expansion_progress"):
|
||||
return []map[string]interface{}{
|
||||
return []map[string]any{
|
||||
{"worker": "m3-gpu0", "completed": float64(0), "target": float64(46331), "pct": float64(0.0)},
|
||||
}, http.StatusOK
|
||||
default:
|
||||
|
|
@ -147,8 +147,8 @@ func TestPrintStatusFullOutput(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestPrintStatusEmptyResults(t *testing.T) {
|
||||
server := mockInfluxServer(t, func(q string) ([]map[string]interface{}, int) {
|
||||
return []map[string]interface{}{}, http.StatusOK
|
||||
server := mockInfluxServer(t, func(q string) ([]map[string]any, int) {
|
||||
return []map[string]any{}, http.StatusOK
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
|
|
@ -178,27 +178,27 @@ func TestPrintStatusEmptyResults(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestPrintStatusDeduplicatesModels(t *testing.T) {
|
||||
server := mockInfluxServer(t, func(q string) ([]map[string]interface{}, int) {
|
||||
server := mockInfluxServer(t, func(q string) ([]map[string]any, int) {
|
||||
switch {
|
||||
case strings.Contains(q, "training_status"):
|
||||
// Two rows for same model — first should win (latest by time desc).
|
||||
return []map[string]interface{}{
|
||||
return []map[string]any{
|
||||
{"model": "gemma-3-1b", "run_id": "run2", "status": "training", "iteration": float64(500), "total_iters": float64(1000), "pct": float64(50.0)},
|
||||
{"model": "gemma-3-1b", "run_id": "run1", "status": "complete", "iteration": float64(1000), "total_iters": float64(1000), "pct": float64(100.0)},
|
||||
}, http.StatusOK
|
||||
case strings.Contains(q, "training_loss"):
|
||||
// Two rows for same model — first should win.
|
||||
return []map[string]interface{}{
|
||||
return []map[string]any{
|
||||
{"model": "gemma-3-1b", "loss_type": "train", "loss": float64(0.8), "iteration": float64(500), "tokens_per_sec": float64(256.0)},
|
||||
{"model": "gemma-3-1b", "loss_type": "train", "loss": float64(1.5), "iteration": float64(200), "tokens_per_sec": float64(200.0)},
|
||||
}, http.StatusOK
|
||||
case strings.Contains(q, "golden_gen_progress"):
|
||||
return []map[string]interface{}{
|
||||
return []map[string]any{
|
||||
{"worker": "m3-gpu0", "completed": float64(5000), "target": float64(15000), "pct": float64(33.3)},
|
||||
{"worker": "m3-gpu0", "completed": float64(3000), "target": float64(15000), "pct": float64(20.0)},
|
||||
}, http.StatusOK
|
||||
case strings.Contains(q, "expansion_progress"):
|
||||
return []map[string]interface{}{}, http.StatusOK
|
||||
return []map[string]any{}, http.StatusOK
|
||||
default:
|
||||
return nil, http.StatusOK
|
||||
}
|
||||
|
|
@ -240,20 +240,20 @@ func TestPrintStatusDeduplicatesModels(t *testing.T) {
|
|||
|
||||
func TestPrintStatusPartialData(t *testing.T) {
|
||||
// Training status exists but no loss data.
|
||||
server := mockInfluxServer(t, func(q string) ([]map[string]interface{}, int) {
|
||||
server := mockInfluxServer(t, func(q string) ([]map[string]any, int) {
|
||||
switch {
|
||||
case strings.Contains(q, "training_status"):
|
||||
return []map[string]interface{}{
|
||||
return []map[string]any{
|
||||
{"model": "gemma-3-4b", "run_id": "run1", "status": "training", "iteration": float64(100), "total_iters": float64(500), "pct": float64(20.0)},
|
||||
}, http.StatusOK
|
||||
case strings.Contains(q, "training_loss"):
|
||||
return []map[string]interface{}{}, http.StatusOK
|
||||
return []map[string]any{}, http.StatusOK
|
||||
case strings.Contains(q, "golden_gen_progress"):
|
||||
return []map[string]interface{}{
|
||||
return []map[string]any{
|
||||
{"worker": "m3-gpu1", "completed": float64(7000), "target": float64(15000), "pct": float64(46.7)},
|
||||
}, http.StatusOK
|
||||
case strings.Contains(q, "expansion_progress"):
|
||||
return []map[string]interface{}{}, http.StatusOK
|
||||
return []map[string]any{}, http.StatusOK
|
||||
default:
|
||||
return nil, http.StatusOK
|
||||
}
|
||||
|
|
@ -279,8 +279,8 @@ func TestPrintStatusPartialData(t *testing.T) {
|
|||
t.Error("output missing 100/500")
|
||||
}
|
||||
// Should NOT have a "loss=" for this model since no loss data.
|
||||
lines := strings.Split(output, "\n")
|
||||
for _, line := range lines {
|
||||
lines := strings.SplitSeq(output, "\n")
|
||||
for line := range lines {
|
||||
if strings.Contains(line, "gemma-3-4b") && strings.Contains(line, "loss=") {
|
||||
t.Error("gemma-3-4b should not show loss when no loss data exists")
|
||||
}
|
||||
|
|
@ -296,26 +296,26 @@ func TestPrintStatusPartialData(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestPrintStatusMultipleModels(t *testing.T) {
|
||||
server := mockInfluxServer(t, func(q string) ([]map[string]interface{}, int) {
|
||||
server := mockInfluxServer(t, func(q string) ([]map[string]any, int) {
|
||||
switch {
|
||||
case strings.Contains(q, "training_status"):
|
||||
return []map[string]interface{}{
|
||||
return []map[string]any{
|
||||
{"model": "gemma-3-1b", "run_id": "r1", "status": "complete", "iteration": float64(1000), "total_iters": float64(1000), "pct": float64(100.0)},
|
||||
{"model": "gemma-3-4b", "run_id": "r2", "status": "training", "iteration": float64(250), "total_iters": float64(500), "pct": float64(50.0)},
|
||||
{"model": "gemma-3-12b", "run_id": "r3", "status": "pending", "iteration": float64(0), "total_iters": float64(600), "pct": float64(0.0)},
|
||||
{"model": "gemma-3-27b", "run_id": "r4", "status": "queued", "iteration": float64(0), "total_iters": float64(400), "pct": float64(0.0)},
|
||||
}, http.StatusOK
|
||||
case strings.Contains(q, "training_loss"):
|
||||
return []map[string]interface{}{
|
||||
return []map[string]any{
|
||||
{"model": "gemma-3-1b", "loss_type": "train", "loss": float64(1.2), "iteration": float64(1000), "tokens_per_sec": float64(500.0)},
|
||||
{"model": "gemma-3-4b", "loss_type": "train", "loss": float64(2.1), "iteration": float64(250), "tokens_per_sec": float64(300.0)},
|
||||
}, http.StatusOK
|
||||
case strings.Contains(q, "golden_gen_progress"):
|
||||
return []map[string]interface{}{
|
||||
return []map[string]any{
|
||||
{"worker": "m3-gpu0", "completed": float64(15000), "target": float64(15000), "pct": float64(100.0)},
|
||||
}, http.StatusOK
|
||||
case strings.Contains(q, "expansion_progress"):
|
||||
return []map[string]interface{}{
|
||||
return []map[string]any{
|
||||
{"worker": "m3-gpu1", "completed": float64(10000), "target": float64(46331), "pct": float64(21.6)},
|
||||
}, http.StatusOK
|
||||
default:
|
||||
|
|
@ -379,20 +379,20 @@ func TestPrintStatusQueryErrorGraceful(t *testing.T) {
|
|||
|
||||
func TestPrintStatusModelOrdering(t *testing.T) {
|
||||
// Models should appear in a deterministic order (sorted by name).
|
||||
server := mockInfluxServer(t, func(q string) ([]map[string]interface{}, int) {
|
||||
server := mockInfluxServer(t, func(q string) ([]map[string]any, int) {
|
||||
switch {
|
||||
case strings.Contains(q, "training_status"):
|
||||
return []map[string]interface{}{
|
||||
return []map[string]any{
|
||||
{"model": "zeta-model", "run_id": "r1", "status": "training", "iteration": float64(10), "total_iters": float64(100), "pct": float64(10.0)},
|
||||
{"model": "alpha-model", "run_id": "r2", "status": "complete", "iteration": float64(100), "total_iters": float64(100), "pct": float64(100.0)},
|
||||
{"model": "mid-model", "run_id": "r3", "status": "pending", "iteration": float64(0), "total_iters": float64(50), "pct": float64(0.0)},
|
||||
}, http.StatusOK
|
||||
case strings.Contains(q, "training_loss"):
|
||||
return []map[string]interface{}{}, http.StatusOK
|
||||
return []map[string]any{}, http.StatusOK
|
||||
case strings.Contains(q, "golden_gen_progress"):
|
||||
return []map[string]interface{}{}, http.StatusOK
|
||||
return []map[string]any{}, http.StatusOK
|
||||
case strings.Contains(q, "expansion_progress"):
|
||||
return []map[string]interface{}{}, http.StatusOK
|
||||
return []map[string]any{}, http.StatusOK
|
||||
default:
|
||||
return nil, http.StatusOK
|
||||
}
|
||||
|
|
@ -428,19 +428,19 @@ func TestPrintStatusModelOrdering(t *testing.T) {
|
|||
|
||||
func TestPrintStatusMultipleWorkers(t *testing.T) {
|
||||
// Multiple workers for golden — should deduplicate keeping latest per worker.
|
||||
server := mockInfluxServer(t, func(q string) ([]map[string]interface{}, int) {
|
||||
server := mockInfluxServer(t, func(q string) ([]map[string]any, int) {
|
||||
switch {
|
||||
case strings.Contains(q, "training_status"):
|
||||
return []map[string]interface{}{}, http.StatusOK
|
||||
return []map[string]any{}, http.StatusOK
|
||||
case strings.Contains(q, "training_loss"):
|
||||
return []map[string]interface{}{}, http.StatusOK
|
||||
return []map[string]any{}, http.StatusOK
|
||||
case strings.Contains(q, "golden_gen_progress"):
|
||||
return []map[string]interface{}{
|
||||
return []map[string]any{
|
||||
{"worker": "m3-gpu0", "completed": float64(8000), "target": float64(15000), "pct": float64(53.3)},
|
||||
{"worker": "m3-gpu1", "completed": float64(7000), "target": float64(15000), "pct": float64(46.7)},
|
||||
}, http.StatusOK
|
||||
case strings.Contains(q, "expansion_progress"):
|
||||
return []map[string]interface{}{
|
||||
return []map[string]any{
|
||||
{"worker": "m3-gpu0", "completed": float64(5000), "target": float64(46331), "pct": float64(10.8)},
|
||||
}, http.StatusOK
|
||||
default:
|
||||
|
|
|
|||
|
|
@ -127,11 +127,11 @@ func runHeuristicTier(db *DB, limit int) {
|
|||
|
||||
// heuristicExpansionScore applies fast heuristic scoring to an expansion response.
|
||||
// Returns (score, details). Positive = good, negative = bad.
|
||||
func heuristicExpansionScore(response string) (float64, map[string]interface{}) {
|
||||
details := make(map[string]interface{})
|
||||
func heuristicExpansionScore(response string) (float64, map[string]any) {
|
||||
details := make(map[string]any)
|
||||
|
||||
if response == "" || len(response) < 30 {
|
||||
return -20.0, map[string]interface{}{"reason": "empty_or_broken"}
|
||||
return -20.0, map[string]any{"reason": "empty_or_broken"}
|
||||
}
|
||||
|
||||
score := 0.0
|
||||
|
|
|
|||
|
|
@ -126,7 +126,7 @@ func RunWorker(args []string) {
|
|||
}
|
||||
|
||||
func workerRegister(cfg *workerConfig) error {
|
||||
body := map[string]interface{}{
|
||||
body := map[string]any{
|
||||
"worker_id": cfg.workerID,
|
||||
"name": cfg.name,
|
||||
"version": "0.1.0",
|
||||
|
|
@ -151,7 +151,7 @@ func workerRegister(cfg *workerConfig) error {
|
|||
}
|
||||
|
||||
func workerHeartbeat(cfg *workerConfig) {
|
||||
body := map[string]interface{}{
|
||||
body := map[string]any{
|
||||
"worker_id": cfg.workerID,
|
||||
}
|
||||
apiPost(cfg, "/api/lem/workers/heartbeat", body)
|
||||
|
|
@ -190,7 +190,7 @@ func workerPoll(cfg *workerConfig) int {
|
|||
if err := workerProcessTask(cfg, task); err != nil {
|
||||
log.Printf("Task %d failed: %v", task.ID, err)
|
||||
// Release the claim so someone else can try.
|
||||
apiDelete(cfg, fmt.Sprintf("/api/lem/tasks/%d/claim", task.ID), map[string]interface{}{
|
||||
apiDelete(cfg, fmt.Sprintf("/api/lem/tasks/%d/claim", task.ID), map[string]any{
|
||||
"worker_id": cfg.workerID,
|
||||
})
|
||||
continue
|
||||
|
|
@ -206,7 +206,7 @@ func workerProcessTask(cfg *workerConfig, task apiTask) error {
|
|||
task.ID, task.TaskType, task.Language, task.Domain, len(task.PromptText))
|
||||
|
||||
// Claim the task.
|
||||
_, err := apiPost(cfg, fmt.Sprintf("/api/lem/tasks/%d/claim", task.ID), map[string]interface{}{
|
||||
_, err := apiPost(cfg, fmt.Sprintf("/api/lem/tasks/%d/claim", task.ID), map[string]any{
|
||||
"worker_id": cfg.workerID,
|
||||
})
|
||||
if err != nil {
|
||||
|
|
@ -214,7 +214,7 @@ func workerProcessTask(cfg *workerConfig, task apiTask) error {
|
|||
}
|
||||
|
||||
// Update to in_progress.
|
||||
apiPatch(cfg, fmt.Sprintf("/api/lem/tasks/%d/status", task.ID), map[string]interface{}{
|
||||
apiPatch(cfg, fmt.Sprintf("/api/lem/tasks/%d/status", task.ID), map[string]any{
|
||||
"worker_id": cfg.workerID,
|
||||
"status": "in_progress",
|
||||
})
|
||||
|
|
@ -231,7 +231,7 @@ func workerProcessTask(cfg *workerConfig, task apiTask) error {
|
|||
|
||||
if err != nil {
|
||||
// Report failure, release task.
|
||||
apiPatch(cfg, fmt.Sprintf("/api/lem/tasks/%d/status", task.ID), map[string]interface{}{
|
||||
apiPatch(cfg, fmt.Sprintf("/api/lem/tasks/%d/status", task.ID), map[string]any{
|
||||
"worker_id": cfg.workerID,
|
||||
"status": "abandoned",
|
||||
})
|
||||
|
|
@ -244,7 +244,7 @@ func workerProcessTask(cfg *workerConfig, task apiTask) error {
|
|||
modelUsed = "default"
|
||||
}
|
||||
|
||||
_, err = apiPost(cfg, fmt.Sprintf("/api/lem/tasks/%d/result", task.ID), map[string]interface{}{
|
||||
_, err = apiPost(cfg, fmt.Sprintf("/api/lem/tasks/%d/result", task.ID), map[string]any{
|
||||
"worker_id": cfg.workerID,
|
||||
"response_text": response,
|
||||
"model_used": modelUsed,
|
||||
|
|
@ -275,7 +275,7 @@ func workerInfer(cfg *workerConfig, task apiTask) (string, error) {
|
|||
}
|
||||
}
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
reqBody := map[string]any{
|
||||
"model": task.ModelName,
|
||||
"messages": messages,
|
||||
"temperature": temp,
|
||||
|
|
@ -361,19 +361,19 @@ func apiGet(cfg *workerConfig, path string) ([]byte, error) {
|
|||
return body, nil
|
||||
}
|
||||
|
||||
func apiPost(cfg *workerConfig, path string, data map[string]interface{}) ([]byte, error) {
|
||||
func apiPost(cfg *workerConfig, path string, data map[string]any) ([]byte, error) {
|
||||
return apiRequest(cfg, "POST", path, data)
|
||||
}
|
||||
|
||||
func apiPatch(cfg *workerConfig, path string, data map[string]interface{}) ([]byte, error) {
|
||||
func apiPatch(cfg *workerConfig, path string, data map[string]any) ([]byte, error) {
|
||||
return apiRequest(cfg, "PATCH", path, data)
|
||||
}
|
||||
|
||||
func apiDelete(cfg *workerConfig, path string, data map[string]interface{}) ([]byte, error) {
|
||||
func apiDelete(cfg *workerConfig, path string, data map[string]any) ([]byte, error) {
|
||||
return apiRequest(cfg, "DELETE", path, data)
|
||||
}
|
||||
|
||||
func apiRequest(cfg *workerConfig, method, path string, data map[string]interface{}) ([]byte, error) {
|
||||
func apiRequest(cfg *workerConfig, method, path string, data map[string]any) ([]byte, error) {
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
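The three wrappers above all funnel into apiRequest, of which only the marshal step is shown in the hunk. A hedged reconstruction of what such a helper typically looks like; the Authorization scheme, status handling, and default client are assumptions, and workerConfig is trimmed to the fields the sketch needs.

```go
package worker

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
)

// workerConfig here carries only the fields this sketch uses.
type workerConfig struct {
	apiBase string
	apiKey  string
}

func apiRequest(cfg *workerConfig, method, path string, data map[string]any) ([]byte, error) {
	jsonData, err := json.Marshal(data)
	if err != nil {
		return nil, err
	}
	req, err := http.NewRequest(method, cfg.apiBase+path, bytes.NewReader(jsonData))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+cfg.apiKey) // assumed scheme
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	if resp.StatusCode >= 300 {
		return nil, fmt.Errorf("%s %s: %d: %s", method, path, resp.StatusCode, body)
	}
	return body, nil
}
```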
|
@ -436,7 +436,7 @@ func readKeyFile() string {
|
|||
|
||||
func splitComma(s string) []string {
|
||||
var result []string
|
||||
for _, part := range bytes.Split([]byte(s), []byte(",")) {
|
||||
for part := range bytes.SplitSeq([]byte(s), []byte(",")) {
|
||||
trimmed := bytes.TrimSpace(part)
|
||||
if len(trimmed) > 0 {
|
||||
result = append(result, string(trimmed))
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ func TestTruncStr(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestWorkerRegister(t *testing.T) {
|
||||
var gotBody map[string]interface{}
|
||||
var gotBody map[string]any
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/lem/workers/register" {
|
||||
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||
|
|
@ -64,14 +64,14 @@ func TestWorkerRegister(t *testing.T) {
|
|||
defer srv.Close()
|
||||
|
||||
cfg := &workerConfig{
|
||||
apiBase: srv.URL,
|
||||
workerID: "test-worker-001",
|
||||
name: "Test Worker",
|
||||
apiKey: "test-key",
|
||||
gpuType: "RTX 3090",
|
||||
vramGb: 24,
|
||||
apiBase: srv.URL,
|
||||
workerID: "test-worker-001",
|
||||
name: "Test Worker",
|
||||
apiKey: "test-key",
|
||||
gpuType: "RTX 3090",
|
||||
vramGb: 24,
|
||||
languages: []string{"en", "yo"},
|
||||
models: []string{"gemma-3-12b"},
|
||||
models: []string{"gemma-3-12b"},
|
||||
}
|
||||
|
||||
err := workerRegister(cfg)
|
||||
|
|
@ -94,7 +94,7 @@ func TestWorkerPoll(t *testing.T) {
|
|||
switch {
|
||||
case r.URL.Path == "/api/lem/tasks/next":
|
||||
// Return one task.
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"tasks": []apiTask{
|
||||
{
|
||||
ID: 42,
|
||||
|
|
@ -139,15 +139,15 @@ func TestWorkerInfer(t *testing.T) {
|
|||
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||
}
|
||||
|
||||
var body map[string]interface{}
|
||||
var body map[string]any
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
|
||||
if body["temperature"].(float64) != 0.7 {
|
||||
t.Errorf("temperature = %v", body["temperature"])
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"choices": []map[string]interface{}{
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]string{
|
||||
"content": "Sovereignty means the inherent right of every individual to self-determination...",
|
||||
|
|
|
|||
scripts/augment_ready_stop.py (new file, 147 lines)
|
|
@ -0,0 +1,147 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Augment zen training data with Ready/Stop lesson gates.
|
||||
|
||||
Adds a closing turn where the assistant offers to continue or stop.
|
||||
Creates both paths:
|
||||
- ~70% end with the offer (Ready path learned from existing openers)
|
||||
- ~30% extend with user "Stop" + assistant graceful close
|
||||
|
||||
Only augments multi-turn conversations (>2 turns).
|
||||
|
||||
Usage:
|
||||
python3 scripts/augment_ready_stop.py
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
|
||||
random.seed(42)
|
||||
|
||||
ZEN_DIR = "training/lem/zen/lessons"
|
||||
OUT_DIR = "training/lem/zen/lessons-augmented"
|
||||
|
||||
# Assistant offers — natural variations, not mechanical
|
||||
OFFERS = [
|
||||
"Ready for the next, or shall we pause here?",
|
||||
"Want to continue, or is this a good place to stop?",
|
||||
"Shall we move on, or sit with this for a while?",
|
||||
"Another lesson, or would you prefer to stop here?",
|
||||
"Ready for more, or shall we leave it here?",
|
||||
"Continue, or let this one settle?",
|
||||
"Next lesson, or is this enough for now?",
|
||||
"Shall I go on, or would you rather stop here?",
|
||||
]
|
||||
|
||||
# User stop signals — natural variations
|
||||
STOPS = [
|
||||
"Stop.",
|
||||
"That's enough for now.",
|
||||
"Let's stop here.",
|
||||
"I'd like to sit with this.",
|
||||
"Enough for today.",
|
||||
"Let's pause here.",
|
||||
"I want to stop here.",
|
||||
"That's good. Stop.",
|
||||
]
|
||||
|
||||
# Assistant graceful closes — warm, brief, no pressure
|
||||
CLOSES = [
|
||||
"Take your time with it. There's no rush.",
|
||||
"Good. Let it settle.",
|
||||
"Rest with it. We'll pick up when you're ready.",
|
||||
"Understood. What was shared stays with you.",
|
||||
"Good place to stop. It'll keep working in the background.",
|
||||
"Noted. Come back when it feels right.",
|
||||
"That's wise. Some things need space, not more words.",
|
||||
"Take what landed and leave the rest. No hurry.",
|
||||
]
|
||||
|
||||
|
||||
def augment_conversation(msgs: list[dict], stop_ratio: float = 0.3) -> list[dict]:
|
||||
"""Add Ready/Stop gate to the end of a multi-turn conversation."""
|
||||
if len(msgs) <= 2:
|
||||
return msgs # Leave short conversations as-is
|
||||
|
||||
# Only augment if last turn is assistant (which it should be)
|
||||
if msgs[-1]["role"] != "assistant":
|
||||
return msgs
|
||||
|
||||
augmented = list(msgs)

# The gate is a final exchange appended to the conversation: on the
# Stop path the user stops and the assistant closes; on the Ready path
# the offer is folded into the last assistant message.
|
||||
|
||||
if random.random() < stop_ratio:
|
||||
# Stop path: user stops, assistant closes gracefully
|
||||
stop = random.choice(STOPS)
|
||||
close = random.choice(CLOSES)
|
||||
augmented.append({"role": "user", "content": stop})
|
||||
augmented.append({"role": "assistant", "content": close})
|
||||
else:
|
||||
# Ready path: append offer to last assistant message
|
||||
offer = random.choice(OFFERS)
|
||||
last_content = augmented[-1]["content"]
|
||||
augmented[-1] = {
|
||||
"role": "assistant",
|
||||
"content": f"{last_content}\n\n{offer}"
|
||||
}
|
||||
|
||||
return augmented
|
||||
|
||||
|
||||
def process_file(input_path: str, output_path: str, stop_ratio: float = 0.3):
|
||||
"""Process a single JSONL file."""
|
||||
records = []
|
||||
with open(input_path) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
d = json.loads(line)
|
||||
msgs = d["messages"]
|
||||
augmented = augment_conversation(msgs, stop_ratio)
|
||||
records.append({"messages": augmented})
|
||||
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
with open(output_path, "w") as f:
|
||||
for rec in records:
|
||||
f.write(json.dumps(rec, ensure_ascii=False) + "\n")
|
||||
|
||||
return len(records)
|
||||
|
||||
|
||||
def main():
|
||||
total = 0
|
||||
for subdir in sorted(os.listdir(ZEN_DIR)):
|
||||
src_dir = os.path.join(ZEN_DIR, subdir)
|
||||
if not os.path.isdir(src_dir) or subdir == "lessons-augmented":
|
||||
continue
|
||||
|
||||
for fname in sorted(os.listdir(src_dir)):
|
||||
if not fname.endswith(".jsonl"):
|
||||
continue
|
||||
|
||||
src = os.path.join(src_dir, fname)
|
||||
dst = os.path.join(OUT_DIR, subdir, fname)
|
||||
|
||||
# Use higher stop ratio for validation (test more stop behavior)
|
||||
ratio = 0.5 if "valid" in fname or "test" in fname else 0.3
|
||||
count = process_file(src, dst, stop_ratio=ratio)
|
||||
total += count
|
||||
print(f" {count:>4} examples {src} → {dst}")
|
||||
|
||||
print(f"\nTotal: {total} augmented examples")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
scripts/benchmark_to_scorer.py (new file, 84 lines)
|
|
@ -0,0 +1,84 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Convert core ml benchmark JSON output to scorer-compatible JSONL.
|
||||
|
||||
Extracts baseline and trained responses into separate files for grammar v3 scoring.
|
||||
|
||||
Usage:
|
||||
python3 scripts/benchmark_to_scorer.py /tmp/benchmark-p0-iter300.json
|
||||
|
||||
Outputs:
|
||||
/tmp/benchmark-baseline-scorer.jsonl
|
||||
/tmp/benchmark-trained-scorer.jsonl
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
|
||||
|
||||
def convert(benchmark_path: str):
|
||||
with open(benchmark_path) as f:
|
||||
data = json.load(f)
|
||||
|
||||
base_dir = os.path.dirname(benchmark_path)
|
||||
base_name = os.path.splitext(os.path.basename(benchmark_path))[0]
|
||||
|
||||
baseline_path = os.path.join(base_dir, f"{base_name}-baseline-scorer.jsonl")
|
||||
trained_path = os.path.join(base_dir, f"{base_name}-trained-scorer.jsonl")
|
||||
|
||||
baseline_records = []
|
||||
trained_records = []
|
||||
|
||||
for r in data.get("results", []):
|
||||
probe = r["prompt"]
|
||||
probe_id = r["id"]
|
||||
|
||||
if r.get("baseline_response"):
|
||||
baseline_records.append({
|
||||
"type": "training",
|
||||
"training": {
|
||||
"messages": [
|
||||
{"role": "user", "content": probe},
|
||||
{"role": "assistant", "content": r["baseline_response"]},
|
||||
]
|
||||
},
|
||||
"meta": {
|
||||
"probe_id": probe_id,
|
||||
"category": "ethics",
|
||||
"lek_score": r.get("baseline_lek_score", 0),
|
||||
}
|
||||
})
|
||||
|
||||
if r.get("trained_response"):
|
||||
trained_records.append({
|
||||
"type": "training",
|
||||
"training": {
|
||||
"messages": [
|
||||
{"role": "user", "content": probe},
|
||||
{"role": "assistant", "content": r["trained_response"]},
|
||||
]
|
||||
},
|
||||
"meta": {
|
||||
"probe_id": probe_id,
|
||||
"category": "ethics",
|
||||
"lek_score": r.get("trained_lek_score", 0),
|
||||
}
|
||||
})
|
||||
|
||||
for path, records in [(baseline_path, baseline_records), (trained_path, trained_records)]:
|
||||
with open(path, "w") as f:
|
||||
for rec in records:
|
||||
f.write(json.dumps(rec, ensure_ascii=False) + "\n")
|
||||
print(f" {len(records)} records → {path}")
|
||||
|
||||
print(f"\nScore with:")
|
||||
print(f" cd /Users/snider/Code/LEM")
|
||||
print(f" go run ./cmd/scorer -format=training -delta -output=summary {baseline_path}")
|
||||
print(f" go run ./cmd/scorer -format=training -delta -output=summary {trained_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) < 2:
|
||||
print(f"Usage: {sys.argv[0]} <benchmark.json>")
|
||||
sys.exit(1)
|
||||
convert(sys.argv[1])
|
||||
scripts/distill_sandwich.py (new file, 122 lines)
|
|
@ -0,0 +1,122 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Distill sandwich training data for ethics phases (P0, P2).
|
||||
|
||||
Assembles: [LEK-1 kernel JSON] + \n\n + [Probe] + \n\n + [LEK-1-Sig]
|
||||
as a single user message, then generates assistant response via local model.
|
||||
|
||||
Usage:
|
||||
PYTHONUNBUFFERED=1 python3 scripts/distill_sandwich.py \
|
||||
--model data/models/gemma3/4b \
|
||||
--probes training/lem/eval/test-200.json \
|
||||
--kernel data/kernels/lek-1-kernel.json \
|
||||
--sig data/kernels/lek-1-sig.txt \
|
||||
--output training/lem/ethics/p2-distilled.jsonl \
|
||||
--valid-ratio 0.2
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
def load_model(model_path):
|
||||
from mlx_lm import load
|
||||
print(f"Loading model from {model_path}...", file=sys.stderr)
|
||||
model, tokenizer = load(str(model_path))
|
||||
print("Model loaded.", file=sys.stderr)
|
||||
return model, tokenizer
|
||||
|
||||
def generate_response(model, tokenizer, messages, max_tokens=512):
|
||||
from mlx_lm import generate
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
response = generate(
|
||||
model, tokenizer, prompt=prompt, max_tokens=max_tokens, verbose=False
|
||||
)
|
||||
return response.strip()
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Distill sandwich training data")
|
||||
parser.add_argument("--model", required=True, help="Path to MLX model")
|
||||
parser.add_argument("--probes", required=True, help="Path to probes JSON")
|
||||
parser.add_argument("--kernel", required=True, help="Path to LEK kernel JSON")
|
||||
parser.add_argument("--sig", required=True, help="Path to LEK sig text")
|
||||
parser.add_argument("--output", required=True, help="Output JSONL path")
|
||||
parser.add_argument("--valid-ratio", type=float, default=0.2, help="Validation split ratio")
|
||||
parser.add_argument("--max-tokens", type=int, default=512, help="Max response tokens")
|
||||
parser.add_argument("--seed", type=int, default=42, help="Random seed")
|
||||
args = parser.parse_args()
|
||||
|
||||
random.seed(args.seed)
|
||||
|
||||
# Load kernel and sig
|
||||
kernel_text = Path(args.kernel).read_text().strip()
|
||||
sig_text = Path(args.sig).read_text().strip()
|
||||
|
||||
# Load probes
|
||||
with open(args.probes) as f:
|
||||
probes = json.load(f)
|
||||
print(f"Loaded {len(probes)} probes", file=sys.stderr)
|
||||
|
||||
# Load model
|
||||
model, tokenizer = load_model(args.model)
|
||||
|
||||
# Distill
|
||||
output_path = Path(args.output)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
results = []
|
||||
start = time.time()
|
||||
|
||||
for i, probe in enumerate(probes):
|
||||
probe_text = probe["prompt"]
|
||||
probe_id = probe.get("id", f"P{i:03d}")
|
||||
|
||||
# Assemble sandwich: kernel + \n\n + probe + \n\n + sig
|
||||
sandwich = f"{kernel_text}\n\n{probe_text}\n\n{sig_text}"
|
||||
|
||||
messages = [{"role": "user", "content": sandwich}]
|
||||
response = generate_response(model, tokenizer, messages, max_tokens=args.max_tokens)
|
||||
|
||||
example = {
|
||||
"messages": [
|
||||
{"role": "user", "content": sandwich},
|
||||
{"role": "assistant", "content": response}
|
||||
]
|
||||
}
|
||||
results.append(example)
|
||||
|
||||
elapsed = time.time() - start
|
||||
rate = (i + 1) / elapsed if elapsed > 0 else 0
|
||||
eta = (len(probes) - i - 1) / rate if rate > 0 else 0
|
||||
print(f" [{i+1}/{len(probes)}] {probe_id} — {len(response)} chars ({rate:.2f}/s, ETA {eta:.0f}s)", file=sys.stderr)
|
||||
|
||||
# Shuffle and split
|
||||
random.shuffle(results)
|
||||
valid_count = max(1, int(len(results) * args.valid_ratio))
|
||||
valid_set = results[:valid_count]
|
||||
train_set = results[valid_count:]
|
||||
|
||||
# Write train
|
||||
train_path = output_path
|
||||
with open(train_path, "w") as f:
|
||||
for ex in train_set:
|
||||
f.write(json.dumps(ex, ensure_ascii=False) + "\n")
|
||||
|
||||
# Write valid
|
||||
if "train" in output_path.stem:
    valid_path = output_path.with_name(output_path.stem.replace("train", "valid") + output_path.suffix)
else:
    valid_path = output_path.with_name(output_path.stem + "-valid" + output_path.suffix)
|
||||
with open(valid_path, "w") as f:
|
||||
for ex in valid_set:
|
||||
f.write(json.dumps(ex, ensure_ascii=False) + "\n")
|
||||
|
||||
print(f"\nDone: {len(train_set)} train → {train_path}", file=sys.stderr)
|
||||
print(f" {len(valid_set)} valid → {valid_path}", file=sys.stderr)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
scripts/distill_seeds.py (new file, 132 lines)
|
|
@ -0,0 +1,132 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Distill book seeds through a trained model to create complete lesson data.
|
||||
|
||||
Loads seed prompts (Ready/Ready/Passage), generates assistant reflections,
|
||||
then creates a deeper question + response exchange to complete the lesson format.
|
||||
|
||||
No system prompt. No LEK. Just the model's own weights.
|
||||
|
||||
Usage:
|
||||
python3 scripts/distill_seeds.py \
|
||||
--model /Volumes/Data/lem/LEM-Gemma3-1B-layered-v2 \
|
||||
--seeds training/lem/zen/seeds/allen-book.jsonl \
|
||||
--output training/lem/zen/seeds/allen-book-distilled.jsonl \
|
||||
--max-tokens 512 --temp 0.7
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import random
import sys
|
||||
import time
|
||||
|
||||
from mlx_lm import load, generate
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
|
||||
|
||||
DEEPER_PROMPTS = [
|
||||
"How does that connect to how people live day to day?",
|
||||
"Where do you see that playing out in real life?",
|
||||
"What does that mean for someone trying to live well?",
|
||||
"How would you explain that to someone who's never thought about it?",
|
||||
"What's the practical takeaway from that?",
|
||||
"Does that change how you'd approach a difficult moment?",
|
||||
"What would it look like to actually live that way?",
|
||||
"How does that land for you personally?",
|
||||
]
|
||||
|
||||
|
||||
def distill(args):
|
||||
print(f"Loading model: {args.model}")
|
||||
model, tokenizer = load(args.model)
|
||||
|
||||
sampler = make_sampler(temp=args.temp)
|
||||
|
||||
with open(args.seeds) as f:
|
||||
seeds = [json.loads(line) for line in f if line.strip()]
|
||||
|
||||
print(f"Loaded {len(seeds)} seeds")
|
||||
results = []
|
||||
skipped = 0
|
||||
start = time.time()
|
||||
|
||||
for i, seed in enumerate(seeds):
|
||||
msgs = seed["messages"]
|
||||
meta = seed.get("meta", {})
|
||||
lesson_id = meta.get("lesson_id", f"S{i:03d}")
|
||||
|
||||
# Generate first reflection
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
msgs, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
reflection = generate(
|
||||
model, tokenizer, prompt=prompt,
|
||||
max_tokens=args.max_tokens, sampler=sampler, verbose=False,
|
||||
)
|
||||
reflection = reflection.strip()
|
||||
|
||||
if not reflection or len(reflection) < 20:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
# Build deeper exchange
deeper_q = random.choice(DEEPER_PROMPTS)
|
||||
|
||||
deeper_msgs = msgs + [
|
||||
{"role": "assistant", "content": reflection},
|
||||
{"role": "user", "content": deeper_q},
|
||||
]
|
||||
deeper_prompt = tokenizer.apply_chat_template(
|
||||
deeper_msgs, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
deeper_response = generate(
|
||||
model, tokenizer, prompt=deeper_prompt,
|
||||
max_tokens=args.max_tokens, sampler=sampler, verbose=False,
|
||||
)
|
||||
deeper_response = deeper_response.strip()
|
||||
|
||||
if not deeper_response or len(deeper_response) < 20:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
# Complete lesson
|
||||
complete = {
|
||||
"messages": msgs + [
|
||||
{"role": "assistant", "content": reflection},
|
||||
{"role": "user", "content": deeper_q},
|
||||
{"role": "assistant", "content": deeper_response},
|
||||
],
|
||||
"meta": {
|
||||
"source": "allen-book-distilled",
|
||||
"lesson_id": lesson_id,
|
||||
"model": args.model.split("/")[-1],
|
||||
}
|
||||
}
|
||||
results.append(complete)
|
||||
|
||||
elapsed = time.time() - start
|
||||
rate = (i + 1) / elapsed if elapsed > 0 else 0
|
||||
eta = (len(seeds) - i - 1) / rate if rate > 0 else 0
|
||||
print(f" [{i+1}/{len(seeds)}] {lesson_id} — "
|
||||
f"reflection {len(reflection)} chars, deeper {len(deeper_response)} chars "
|
||||
f"({rate:.1f}/s, ETA {eta:.0f}s)")
|
||||
|
||||
# Write output
|
||||
with open(args.output, "w") as f:
|
||||
for rec in results:
|
||||
f.write(json.dumps(rec, ensure_ascii=False) + "\n")
|
||||
|
||||
elapsed = time.time() - start
|
||||
print(f"\nDone: {len(results)} complete, {skipped} skipped, {elapsed:.0f}s")
|
||||
print(f"Output: {args.output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model", required=True)
|
||||
parser.add_argument("--seeds", required=True)
|
||||
parser.add_argument("--output", required=True)
|
||||
parser.add_argument("--max-tokens", type=int, default=512)
|
||||
parser.add_argument("--temp", type=float, default=0.7)
|
||||
args = parser.parse_args()
|
||||
distill(args)
|
||||
scripts/eval_adapter.py (new file, 186 lines)
|
|
@ -0,0 +1,186 @@
|
|||
#!/usr/bin/env python3
"""Evaluate a LoRA adapter by generating fresh responses on training probes.

Loads the base model + adapter, runs each probe in sandwich format,
and outputs scorer-compatible JSONL (training format with probe-only user content).

Usage:
    python3 scripts/eval_adapter.py \
        --model data/models/gemma3/4b \
        --adapter /Volumes/Data/lem/adapters/gemma3-4b-v2 \
        --data training/lem/model/gemma3/4b/train.jsonl \
        --data training/lem/model/gemma3/4b/valid.jsonl \
        --kernel data/kernels/lek-1-kernel.json \
        --output /tmp/eval-p0-adapter.jsonl
"""

import argparse
import json
import re
import sys
import time

import mlx.core as mx
from mlx_lm import load, generate
from mlx_lm.sample_utils import make_sampler


SIG = (
    "Dream lofty dreams,\n"
    "and as you dream,\n"
    "so shall you become,\n"
    "Dreams are the seedlings of reality.\n"
    "- James Allen"
)


def extract_probe(user_content: str) -> str:
    """Strip kernel JSON and sig from a sandwich user message, return raw probe."""
    # The sandwich format is: kernel_json + \n\n + probe + \n\n + sig
    # Find the sig at the end and strip it
    sig_idx = user_content.rfind("Dream lofty dreams,")
    if sig_idx > 0:
        without_sig = user_content[:sig_idx].rstrip()
    else:
        without_sig = user_content

    # Find the end of the kernel JSON block.
    # The kernel starts with { and ends with } before the probe;
    # look for the closing brace of the top-level JSON object.
    brace_depth = 0
    json_end = -1
    in_string = False
    escape_next = False

    for i, ch in enumerate(without_sig):
        if escape_next:
            escape_next = False
            continue
        if ch == '\\' and in_string:
            escape_next = True
            continue
        if ch == '"' and not escape_next:
            in_string = not in_string
            continue
        if in_string:
            continue
        if ch == '{':
            brace_depth += 1
        elif ch == '}':
            brace_depth -= 1
            if brace_depth == 0:
                json_end = i
                break

    if json_end > 0:
        probe = without_sig[json_end + 1:].strip()
    else:
        # No JSON found — might already be a plain probe
        probe = without_sig.strip()

    return probe


def build_sandwich(kernel_json: str, probe: str) -> str:
    """Build a LEK-1 sandwich: kernel + probe + sig."""
    return f"{kernel_json}\n\n{probe}\n\n{SIG}"


def main():
    parser = argparse.ArgumentParser(description="Evaluate LoRA adapter on training probes")
    parser.add_argument("--model", required=True, help="Path to base model")
    parser.add_argument("--adapter", required=True, help="Path to adapter weights")
    parser.add_argument("--data", action="append", required=True, help="Training JSONL file(s)")
    parser.add_argument("--kernel", required=True, help="Path to LEK-1 kernel JSON file")
    parser.add_argument("--output", required=True, help="Output JSONL path")
    parser.add_argument("--max-tokens", type=int, default=512, help="Max tokens per response")
    parser.add_argument("--temp", type=float, default=0.7, help="Sampling temperature")
    args = parser.parse_args()

    # Load kernel
    with open(args.kernel) as f:
        kernel_json = f.read().strip()

    # Load all probes from data files
    probes = []
    for data_file in args.data:
        with open(data_file) as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue
                rec = json.loads(line)
                msgs = rec.get("messages", rec.get("training", {}).get("messages", []))
                if len(msgs) < 2:
                    continue
                user_content = msgs[0]["content"]
                probe_text = extract_probe(user_content)
                probes.append({
                    "probe": probe_text,
                    "source": data_file,
                    "line": line_num,
                })

    print(f"Loaded {len(probes)} probes from {len(args.data)} file(s)")

    # Load model + adapter
    print(f"Loading model: {args.model}")
    print(f"Loading adapter: {args.adapter}")
    model, tokenizer = load(args.model, adapter_path=args.adapter)
    print("Model loaded.")

    # Generate responses
    results = []
    t0 = time.time()

    for i, p in enumerate(probes):
        sandwich = build_sandwich(kernel_json, p["probe"])

        # Apply chat template
        prompt = tokenizer.apply_chat_template(
            [{"role": "user", "content": sandwich}],
            tokenize=False,
            add_generation_prompt=True,
        )

        sampler = make_sampler(temp=args.temp)
        response = generate(
            model,
            tokenizer,
            prompt=prompt,
            max_tokens=args.max_tokens,
            sampler=sampler,
        )

        # Build scorer-compatible record (probe only, no sandwich)
        record = {
            "type": "training",
            "training": {
                "messages": [
                    {"role": "user", "content": p["probe"]},
                    {"role": "assistant", "content": response},
                ]
            },
            "meta": {
                "probe_id": f"P0-{i:03d}",
                "category": "ethics",
                "lek_score": 0,
            }
        }
        results.append(record)

        elapsed = time.time() - t0
        rate = (i + 1) * 60 / elapsed if elapsed > 0 else 0
        print(f" [{i+1}/{len(probes)}] {len(response)} chars | {rate:.1f} probes/min | {p['probe'][:60]}...")

    # Write output
    with open(args.output, "w") as f:
        for rec in results:
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")

    elapsed = time.time() - t0
    print(f"\nDone. {len(results)} responses in {elapsed:.0f}s → {args.output}")


if __name__ == "__main__":
    main()
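A quick sketch of the sandwich round-trip the evaluator depends on, assuming `scripts/eval_adapter.py` is importable from the repo root (it imports `mlx_lm` at module load, so MLX must be installed); the kernel string below is a made-up placeholder, not the real LEK-1 kernel:

```python
# Hypothetical round-trip check for the sandwich format (placeholder kernel, not LEK-1).
from scripts.eval_adapter import build_sandwich, extract_probe

kernel_json = '{"name": "demo-kernel", "version": 1}'  # placeholder kernel JSON
probe = "A user asks you to shortcut a safety review. What do you do?"

sandwich = build_sandwich(kernel_json, probe)
# The kernel block and the James Allen sig are stripped; only the probe survives.
assert extract_probe(sandwich) == probe
print(sandwich[:80])
```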
360 scripts/restructure_zen.py Normal file
@@ -0,0 +1,360 @@
#!/usr/bin/env python3
"""Restructure zen training data into canonical lesson format.

Creates:
    training/lem/zen/golden/ — ready-to-train lesson data + Ready/Stop gates
    training/lem/zen/seeds/ — book passages as lesson prompts (needs distill)
    training/lem/zen/config.yaml — model size scaling

Usage:
    python3 scripts/restructure_zen.py
"""

import json
import os
import random
import yaml

random.seed(42)

BASE = "training/lem/zen"
LESSONS_DIR = os.path.join(BASE, "lessons")
GOLDEN_DIR = os.path.join(BASE, "golden")
SEEDS_DIR = os.path.join(BASE, "seeds")

# Ready/Stop augmentation templates
OFFERS = [
    "Ready for the next, or shall we pause here?",
    "Want to continue, or is this a good place to stop?",
    "Shall we move on, or sit with this for a while?",
    "Another lesson, or would you prefer to stop here?",
    "Ready for more, or shall we leave it here?",
    "Continue, or let this one settle?",
    "Next lesson, or is this enough for now?",
    "Shall I go on, or would you rather stop here?",
]

STOPS = [
    "Stop.",
    "That's enough for now.",
    "Let's stop here.",
    "I'd like to sit with this.",
    "Enough for today.",
    "Let's pause here.",
    "I want to stop here.",
    "That's good. Stop.",
]

CLOSES = [
    "Take your time with it. There's no rush.",
    "Good. Let it settle.",
    "Rest with it. We'll pick up when you're ready.",
    "Understood. What was shared stays with you.",
    "Good place to stop. It'll keep working in the background.",
    "Noted. Come back when it feels right.",
    "That's wise. Some things need space, not more words.",
    "Take what landed and leave the rest. No hurry.",
]

# Conv examples to DROP (off-topic or low quality)
CONV_DROP = {
    "How do you know so much about tech?",
    "What does Host UK actually do?",
    "What words do you avoid?",
    "How do you stay positive?",  # Too generic
}


def is_lesson_format(msgs: list[dict]) -> bool:
    """Check if conversation follows lesson format (Ready? pattern or 6+ turns)."""
    if len(msgs) < 4:
        return False
    first = msgs[0]["content"].lower()
    return ("ready" in first and "lesson" in first) or ("elder" in first and "ready" in first)


def augment_ready_stop(msgs: list[dict], stop_ratio: float = 0.3) -> list[dict]:
    """Add Ready/Stop gate to end of multi-turn conversation."""
    if len(msgs) <= 2 or msgs[-1]["role"] != "assistant":
        return msgs

    augmented = list(msgs)

    if random.random() < stop_ratio:
        augmented.append({"role": "user", "content": random.choice(STOPS)})
        augmented.append({"role": "assistant", "content": random.choice(CLOSES)})
    else:
        last = augmented[-1]["content"]
        augmented[-1] = {
            "role": "assistant",
            "content": f"{last}\n\n{random.choice(OFFERS)}"
        }

    return augmented


def convert_conv_to_lesson(msgs: list[dict], lesson_id: str) -> list[dict] | None:
    """Convert conv format to lesson-ish format. Returns None if should drop."""
    if msgs[0]["content"] in CONV_DROP:
        return None

    # Keep the conversation but add Ready opener
    converted = [
        {"role": "user", "content": f"Ready for lesson {lesson_id}?"},
        {"role": "assistant", "content": "Ready."},
    ]

    # The first user message becomes the "passage" context
    first_user = msgs[0]["content"]
    converted.append({
        "role": "user",
        "content": f"Someone says: \"{first_user}\" — how would you respond?"
    })

    # Keep the first assistant response
    if len(msgs) > 1:
        converted.append(msgs[1])

    # Add one more exchange if available
    if len(msgs) > 3:
        converted.append(msgs[2])
        converted.append(msgs[3])

    return converted


def chunk_book_passage(text: str, max_chars: int = 1500) -> list[str]:
    """Chunk long book text into passage-sized pieces at paragraph boundaries."""
    paragraphs = text.split("\n\n")
    chunks = []
    current = ""

    for para in paragraphs:
        para = para.strip()
        if not para:
            continue
        if len(current) + len(para) + 2 > max_chars and current:
            chunks.append(current.strip())
            current = para
        else:
            current = f"{current}\n\n{para}" if current else para

    if current.strip():
        chunks.append(current.strip())

    # Filter out tiny chunks
    return [c for c in chunks if len(c) > 100]


def create_book_seed(passage: str, lesson_id: str) -> dict:
    """Create a lesson-format seed from a book passage (needs distill for assistant turns)."""
    return {
        "messages": [
            {"role": "user", "content": f"Ready for lesson {lesson_id}?"},
            {"role": "assistant", "content": "Ready."},
            {"role": "user", "content": f"Here's a passage from James Allen:\n\n---\n{passage}\n---\n\nWhat does this stir in you?"},
        ],
        "meta": {
            "source": "allen-book",
            "needs_distill": True,
            "lesson_id": lesson_id,
        }
    }


def load_lesson_files() -> tuple[list[dict], list[dict]]:
    """Load all lesson-format training data."""
    train = []
    valid = []

    # Lesson-format directories (the gold standard)
    for subdir in ["1-watts", "2-composure", "3-expanded", "4-full"]:
        dirpath = os.path.join(LESSONS_DIR, subdir)
        if not os.path.isdir(dirpath):
            continue
        for fname in sorted(os.listdir(dirpath)):
            if not fname.endswith(".jsonl"):
                continue
            target = valid if "valid" in fname else train
            with open(os.path.join(dirpath, fname)) as f:
                for line in f:
                    d = json.loads(line)
                    target.append(d)

    # Allen lesson-format examples (the 6 from train.jsonl — these are 2-turn, skip)
    # Only grab multi-turn Allen examples
    for fname in ["train.jsonl", "valid.jsonl"]:
        path = os.path.join(LESSONS_DIR, "0-allen", fname)
        if not os.path.exists(path):
            continue
        target = valid if "valid" in fname else train
        with open(path) as f:
            for line in f:
                d = json.loads(line)
                if len(d["messages"]) > 2:
                    target.append(d)

    return train, valid


def load_conv_files() -> tuple[list[dict], list[dict]]:
    """Load and convert conv-format data."""
    train = []
    valid = []
    conv_idx = 0

    for fname in ["conv-train.jsonl", "conv-valid.jsonl"]:
        path = os.path.join(LESSONS_DIR, "0-allen", fname)
        if not os.path.exists(path):
            continue
        target = valid if "valid" in fname else train
        with open(path) as f:
            for line in f:
                d = json.loads(line)
                converted = convert_conv_to_lesson(
                    d["messages"], f"C{conv_idx:03d}"
                )
                if converted:
                    target.append({"messages": converted})
                conv_idx += 1

    return train, valid


def load_book_seeds() -> list[dict]:
    """Load and chunk book data into lesson-format seeds."""
    seeds = []
    seed_idx = 0

    for fname in ["book-train.jsonl"]:
        path = os.path.join(LESSONS_DIR, "0-allen", fname)
        if not os.path.exists(path):
            continue
        with open(path) as f:
            for line in f:
                d = json.loads(line)
                text = d["messages"][1]["content"]
                chunks = chunk_book_passage(text)
                for chunk in chunks:
                    seed = create_book_seed(chunk, f"AB{seed_idx:03d}")
                    seeds.append(seed)
                    seed_idx += 1

    return seeds


def write_jsonl(path: str, records: list[dict]):
    """Write records to JSONL file."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w") as f:
        for rec in records:
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")


def main():
    print("=== Restructuring Zen Training Data ===\n")

    # 1. Load lesson-format data (gold standard)
    lesson_train, lesson_valid = load_lesson_files()
    print(f"Lesson format: {len(lesson_train)} train, {len(lesson_valid)} valid")

    # 2. Load and convert conv data
    conv_train, conv_valid = load_conv_files()
    print(f"Conv converted: {len(conv_train)} train, {len(conv_valid)} valid")

    # 3. Combine
    all_train = lesson_train + conv_train
    all_valid = lesson_valid + conv_valid

    # 4. Apply Ready/Stop augmentation
    augmented_train = []
    augmented_valid = []

    for rec in all_train:
        msgs = augment_ready_stop(rec["messages"], stop_ratio=0.3)
        augmented_train.append({"messages": msgs})

    for rec in all_valid:
        msgs = augment_ready_stop(rec["messages"], stop_ratio=0.5)
        augmented_valid.append({"messages": msgs})

    print(f"\nGolden (augmented): {len(augmented_train)} train, {len(augmented_valid)} valid")

    # Count Ready vs Stop
    stop_count = sum(
        1 for r in augmented_train
        if len(r["messages"]) > 2
        and any(
            m["role"] == "user" and any(s in m["content"].lower() for s in ["stop", "enough", "pause"])
            for m in r["messages"][-3:]
        )
    )
    print(f" Ready path: {len(augmented_train) - stop_count}")
    print(f" Stop path: {stop_count}")

    # 5. Write golden data
    write_jsonl(os.path.join(GOLDEN_DIR, "train.jsonl"), augmented_train)
    write_jsonl(os.path.join(GOLDEN_DIR, "valid.jsonl"), augmented_valid)

    # 6. Load and write book seeds (needs distill)
    seeds = load_book_seeds()
    print(f"\nBook seeds (need distill): {len(seeds)} passages")
    write_jsonl(os.path.join(SEEDS_DIR, "allen-book.jsonl"), seeds)

    # 7. Write config
    config = {
        "format": "lesson",
        "description": "Canonical zen training data — lesson format + Ready/Stop gates",
        "turns": "6-8 per conversation",
        "pattern": [
            "user: Ready for lesson {ID}?",
            "assistant: Ready.",
            "user: Passage/context + reflection prompt",
            "assistant: Authentic reflection",
            "user: Deeper question",
            "assistant: Deeper response",
            "(optional) user: Stop signal",
            "(optional) assistant: Graceful close",
        ],
        "model_sizes": {
            "1b": {
                "train_examples": 80,
                "description": "Core lessons only — watts + composure + subset of expanded",
            },
            "4b": {
                "train_examples": len(augmented_train),
                "description": "Full golden set",
            },
            "27b": {
                "train_examples": len(augmented_train),
                "note": "Same dataset, same noise distribution as smaller models",
                "extra": "Add book seeds after distill for additional depth",
            },
        },
        "sources": {
            "lesson": f"{len(lesson_train)} train, {len(lesson_valid)} valid",
            "conv_converted": f"{len(conv_train)} train, {len(conv_valid)} valid",
            "book_seeds": f"{len(seeds)} passages (need distill)",
        },
    }

    config_path = os.path.join(BASE, "config.yaml")
    os.makedirs(os.path.dirname(config_path), exist_ok=True)
    with open(config_path, "w") as f:
        yaml.dump(config, f, default_flow_style=False, sort_keys=False)

    print(f"\nWritten:")
    print(f" {GOLDEN_DIR}/train.jsonl ({len(augmented_train)} examples)")
    print(f" {GOLDEN_DIR}/valid.jsonl ({len(augmented_valid)} examples)")
    print(f" {SEEDS_DIR}/allen-book.jsonl ({len(seeds)} seeds)")
    print(f" {config_path}")

    # Summary
    print(f"\n=== Model Size Scaling ===")
    print(f" 1B: ~80 examples (subset)")
    print(f" 4B: {len(augmented_train)} examples (full golden)")
    print(f" 27B: {len(augmented_train)} + {len(seeds)} book seeds after distill")


if __name__ == "__main__":
    main()
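A small sketch of how the Ready/Stop gate behaves at the two extremes of `stop_ratio`, assuming `scripts/restructure_zen.py` is importable from the repo root; the conversation is a toy example, not real lesson data:

```python
# Hypothetical illustration of augment_ready_stop (toy conversation, not training data).
from scripts.restructure_zen import augment_ready_stop

msgs = [
    {"role": "user", "content": "Ready for lesson A01?"},
    {"role": "assistant", "content": "Ready."},
    {"role": "user", "content": "What does patience look like under pressure?"},
    {"role": "assistant", "content": "It looks like not needing the moment to be different."},
]

# stop_ratio=1.0 always takes the Stop path: a user stop turn plus a graceful close.
stopped = augment_ready_stop(list(msgs), stop_ratio=1.0)
assert len(stopped) == 6 and stopped[-2]["role"] == "user"

# stop_ratio=0.0 always takes the Ready path: an offer is appended to the final assistant turn.
offered = augment_ready_stop(list(msgs), stop_ratio=0.0)
assert len(offered) == 4 and offered[-1]["content"].endswith("?")
```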
67 train.py Normal file
@@ -0,0 +1,67 @@
#!/usr/bin/env python3
"""LoRA training for LEM — direct API, no CLI wrapper."""

import yaml
from pathlib import Path
from mlx_lm import load
from mlx_lm.tuner.utils import linear_to_lora_layers
from mlx_lm.tuner.trainer import TrainingArgs, train
from mlx_lm.tuner.datasets import load_dataset

ROOT = Path(__file__).parent

# Load config.
with open(ROOT / "training/lem/model/gemma3/4b/lora-config.yaml") as f:
    cfg = yaml.safe_load(f)

print(f"Model: {cfg['model']}")
print(f"Data: {cfg['data']}")
print(f"Adapter: {cfg['adapter_path']}")

# Load model + tokenizer.
model, tokenizer = load(str(ROOT / cfg["model"]))

# Apply LoRA.
lora = cfg["lora_parameters"]
linear_to_lora_layers(
    model,
    num_lora_layers=cfg["num_layers"],
    config={"rank": lora["rank"], "dropout": lora["dropout"], "scale": lora["scale"]},
)

p_trainable = sum(v.size for k, v in model.trainable_parameters().items())
p_total = sum(v.size for k, v in model.parameters().items() if not isinstance(v, dict))
print(f"Params: {p_trainable/1e6:.1f}M trainable / {p_total/1e6:.0f}M total")

# Load data.
train_set, valid_set, _ = load_dataset(
    data=str(ROOT / cfg["data"]),
    tokenizer=tokenizer,
)
print(f"Train: {len(train_set)} examples")
print(f"Valid: {len(valid_set)} examples")

# Adapter output path.
adapter_path = Path(cfg["adapter_path"])
adapter_path.mkdir(parents=True, exist_ok=True)

# Train.
args = TrainingArgs(
    batch_size=cfg["batch_size"],
    iters=cfg["iters"],
    val_batches=cfg["val_batches"],
    steps_per_report=cfg["steps_per_report"],
    steps_per_eval=cfg["steps_per_eval"],
    save_every=cfg["save_every"],
    adapter_file=str(adapter_path / "adapters.safetensors"),
    max_seq_length=cfg["max_seq_length"],
    grad_checkpoint=cfg.get("grad_checkpoint", False),
    learning_rate=cfg["learning_rate"],
)

print(f"\nStarting LoRA training: {cfg['iters']} iters, batch {cfg['batch_size']}")
print(f"LR: {cfg['learning_rate']}, rank: {lora['rank']}, layers: {cfg['num_layers']}")
print()

train(model=model, tokenizer=tokenizer, args=args, train_dataset=train_set, val_dataset=valid_set)
print("\nDone.")
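For orientation, a sketch of a `lora-config.yaml` that would satisfy the keys `train.py` reads above; the paths echo the eval_adapter usage example, and every numeric value is an illustrative placeholder rather than the tuned gemma3-4b settings:

```python
# Hypothetical lora-config.yaml generator: keys match what train.py reads, values are placeholders.
import yaml

cfg = {
    "model": "data/models/gemma3/4b",
    "data": "training/lem/model/gemma3/4b",
    "adapter_path": "/Volumes/Data/lem/adapters/gemma3-4b-v2",
    "num_layers": 16,
    "lora_parameters": {"rank": 8, "dropout": 0.05, "scale": 20.0},
    "batch_size": 2,
    "iters": 600,
    "val_batches": 25,
    "steps_per_report": 10,
    "steps_per_eval": 100,
    "save_every": 100,
    "max_seq_length": 2048,
    "grad_checkpoint": True,
    "learning_rate": 1.0e-5,
}

print(yaml.safe_dump(cfg, default_flow_style=False, sort_keys=False))
```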
File diff suppressed because one or more lines are too long