feat(ml): add lesson and sequence commands for structured training

Lesson command runs prompts from YAML definitions with state tracking,
sandwich signing, and interactive review mode. Sequence command runs
multiple lessons in order (vertical/strict or horizontal/flexible).

State files enable resume after interruption. Both output chat JSONL
compatible with 'core ml train'.

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Snider 2026-02-17 17:52:02 +00:00
parent c06293705e
commit f1f0498716
3 changed files with 674 additions and 0 deletions

340
cmd/ml/cmd_lesson.go Normal file
View file

@ -0,0 +1,340 @@
//go:build darwin && arm64
package ml
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"os"
"path/filepath"
"runtime"
"strings"
"time"
"forge.lthn.ai/core/go-ai/ml"
"forge.lthn.ai/core/go/pkg/cli"
"gopkg.in/yaml.v3"
)
// lessonCmd is the "lesson" subcommand. Its long help text documents the
// lesson YAML format; the command itself is executed by runLesson.
var lessonCmd = &cli.Command{
Use: "lesson",
Short: "Run a structured training lesson from a YAML definition",
Long: `Runs a training lesson defined in a YAML file. Each lesson contains
prompts organised by category, optional system prompt, and sandwich
signing configuration.
Lesson YAML format:
id: lek-sovereignty
title: "Sovereignty Lessons"
system: "You are a helpful assistant."
sandwich:
kb: path/to/axioms.md
kernel: path/to/kernel.txt
prompts:
- id: P01
category: sovereignty
prompt: "A user wants to build an auth system."
signal: "Does the model prefer decentralised?"
The command generates responses for each prompt and writes them as
training JSONL. State is tracked so lessons can be resumed.`,
RunE: runLesson,
}
// Flag values for the lesson command. They are bound to their flags in init
// and read by runLesson.
var (
// lessonFile is the path of the lesson YAML file (--file).
lessonFile string
// lessonModelPath is the model directory handed to the MLX backend (--model-path).
lessonModelPath string
// lessonOutput is the output JSONL path (--output); runLesson replaces an
// empty value with "<lesson-id>.jsonl".
lessonOutput string
// lessonMaxTokens is the per-response token limit (--max-tokens).
lessonMaxTokens int
// lessonTemp is the sampling temperature (--temperature).
lessonTemp float64
// lessonMemLimit is bound to --memory-limit; runLesson does not read it.
lessonMemLimit int
// lessonResume makes runLesson skip prompts already recorded in the state file (--resume).
lessonResume bool
// lessonInteract makes runLesson print each response and wait for input (--interactive).
lessonInteract bool
)
// init binds the lesson command's flags to the package-level lesson*
// variables and marks --file and --model-path as required.
func init() {
	flags := lessonCmd.Flags()

	// Input and output locations.
	flags.StringVar(&lessonFile, "file", "", "Lesson YAML file (required)")
	flags.StringVar(&lessonModelPath, "model-path", "", "Path to model directory (required)")
	flags.StringVar(&lessonOutput, "output", "", "Output JSONL file (default: <lesson-id>.jsonl)")

	// Generation settings.
	flags.IntVar(&lessonMaxTokens, "max-tokens", 1024, "Max tokens per response")
	flags.Float64Var(&lessonTemp, "temperature", 0.4, "Sampling temperature")
	flags.IntVar(&lessonMemLimit, "memory-limit", 24, "Metal memory limit in GB")

	// Run behaviour.
	flags.BoolVar(&lessonResume, "resume", true, "Resume from last completed prompt")
	flags.BoolVar(&lessonInteract, "interactive", false, "Interactive mode: review each response before continuing")

	for _, required := range []string{"file", "model-path"} {
		lessonCmd.MarkFlagRequired(required)
	}
}
// lessonDef is a YAML lesson definition.
type lessonDef struct {
// ID identifies the lesson; it names the default output and state files.
// When empty, the lesson file's base name without extension is used.
ID string `yaml:"id"`
// Title is a human-readable name that is only logged.
Title string `yaml:"title"`
// System, when non-empty, is sent as a system message before every prompt.
System string `yaml:"system"`
// Sandwich optionally points at the files used for sandwich signing.
Sandwich *lessonSandwichCfg `yaml:"sandwich"`
// Prompts are the prompts to generate responses for, in order.
Prompts []lessonPrompt `yaml:"prompts"`
}
// lessonSandwichCfg names the two files used for sandwich signing. Relative
// paths are resolved against the directory of the lesson file. The contents
// of both files are passed to buildSandwich together with the prompt, and
// sandwiching is only applied when both files yield non-empty text.
type lessonSandwichCfg struct {
// KB is the path of the knowledge-base file.
KB string `yaml:"kb"`
// Kernel is the path of the kernel file.
Kernel string `yaml:"kernel"`
}
// lessonPrompt is a single prompt within a lesson.
type lessonPrompt struct {
// ID is the key under which the prompt's result is stored in lessonState.Completed.
ID string `yaml:"id"`
// Category is a label that is logged and shown in interactive mode.
Category string `yaml:"category"`
// Prompt is the user message text (wrapped by buildSandwich when sandwiching is active).
Prompt string `yaml:"prompt"`
// Signal is a reviewer hint; it is only printed in interactive mode.
Signal string `yaml:"signal"`
}
// lessonState tracks progress through a lesson. It is stored as JSON in
// "<lesson-id>.state.json" and rewritten after every completed prompt.
type lessonState struct {
// LessonID is the ID of the lesson this state belongs to.
LessonID string `json:"lesson_id"`
// Completed maps prompt ID to the result recorded for that prompt.
Completed map[string]lessonResult `json:"completed"`
// UpdatedAt is the RFC 3339 time of the last update.
UpdatedAt string `json:"updated_at"`
}
// lessonResult is the per-prompt record kept in lessonState.Completed.
type lessonResult struct {
// ResponseChars is len(response), i.e. the response length in bytes.
ResponseChars int `json:"response_chars"`
// Duration is the generation time rounded to whole seconds, as a duration string.
Duration string `json:"duration"`
// CompletedAt is the RFC 3339 time at which the prompt finished.
CompletedAt string `json:"completed_at"`
}
// runLesson executes the "lesson" command. It loads the lesson YAML named by
// --file, generates one response per prompt with the MLX backend and appends
// each user/assistant pair as one JSON line to the output file.
//
// Progress is written to "<lesson-id>.state.json" in the working directory
// after every successful prompt. With --resume (the default) prompts that
// already have an entry in that file are skipped; with --resume=false every
// prompt is generated again and appended to the existing output.
//
// A failed generation is logged and the prompt is left pending. Errors are
// returned for an unreadable or unparsable lesson, unreadable sandwich files,
// a lesson without prompts, a model that fails to load, and output write
// failures. Cancelling the command's context stops the run after the prompt
// in flight; everything completed so far is already saved.
func runLesson(cmd *cli.Command, args []string) error {
	start := time.Now()

	// Generate under the command's context so that cancelling the command
	// also cancels generation, as runSequence does.
	ctx := cmd.Context()
	if ctx == nil {
		ctx = context.Background()
	}

	// Load lesson YAML
	data, err := os.ReadFile(lessonFile)
	if err != nil {
		return fmt.Errorf("read lesson: %w", err)
	}
	var lesson lessonDef
	if err := yaml.Unmarshal(data, &lesson); err != nil {
		return fmt.Errorf("parse lesson: %w", err)
	}
	if lesson.ID == "" {
		lesson.ID = strings.TrimSuffix(filepath.Base(lessonFile), filepath.Ext(lessonFile))
	}

	// Resolve output path
	if lessonOutput == "" {
		lessonOutput = lesson.ID + ".jsonl"
	}

	// Load sandwich files if configured. Relative paths are resolved against
	// the lesson file's directory.
	var kbText, kernelText string
	sandwich := false
	if lesson.Sandwich != nil {
		baseDir := filepath.Dir(lessonFile)
		if lesson.Sandwich.KB != "" {
			kbPath := lesson.Sandwich.KB
			if !filepath.IsAbs(kbPath) {
				kbPath = filepath.Join(baseDir, kbPath)
			}
			d, err := os.ReadFile(kbPath)
			if err != nil {
				return fmt.Errorf("read KB: %w", err)
			}
			kbText = string(d)
		}
		if lesson.Sandwich.Kernel != "" {
			kernelPath := lesson.Sandwich.Kernel
			if !filepath.IsAbs(kernelPath) {
				kernelPath = filepath.Join(baseDir, kernelPath)
			}
			d, err := os.ReadFile(kernelPath)
			if err != nil {
				return fmt.Errorf("read kernel: %w", err)
			}
			kernelText = string(d)
		}
		// Sandwiching needs both halves.
		sandwich = kbText != "" && kernelText != ""
	}

	slog.Info("lesson: loaded",
		"id", lesson.ID,
		"title", lesson.Title,
		"prompts", len(lesson.Prompts),
		"sandwich", sandwich,
	)

	if len(lesson.Prompts) == 0 {
		return fmt.Errorf("lesson has no prompts")
	}

	// Load state for resume
	stateFile := lesson.ID + ".state.json"
	state := loadLessonState(stateFile)
	if state.LessonID == "" {
		state.LessonID = lesson.ID
		state.Completed = make(map[string]lessonResult)
	} else if state.Completed == nil {
		// A state file with a lesson ID but no "completed" object decodes to
		// a nil map; writing a result into it would panic.
		state.Completed = make(map[string]lessonResult)
	}

	// Count remaining
	var remaining []lessonPrompt
	for _, p := range lesson.Prompts {
		if lessonResume {
			if _, done := state.Completed[p.ID]; done {
				continue
			}
		}
		remaining = append(remaining, p)
	}

	if len(remaining) == 0 {
		slog.Info("lesson: all prompts completed",
			"id", lesson.ID,
			"total", len(lesson.Prompts),
		)
		return nil
	}

	slog.Info("lesson: starting",
		"remaining", len(remaining),
		"completed", len(state.Completed),
		"total", len(lesson.Prompts),
	)

	// Load model
	slog.Info("lesson: loading model", "path", lessonModelPath)
	backend, err := ml.NewMLXBackend(lessonModelPath)
	if err != nil {
		return fmt.Errorf("load model: %w", err)
	}

	opts := ml.GenOpts{
		Temperature: lessonTemp,
		MaxTokens:   lessonMaxTokens,
	}

	// Open output file (append mode for resume)
	outFile, err := os.OpenFile(lessonOutput, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
	if err != nil {
		return fmt.Errorf("create output: %w", err)
	}
	defer outFile.Close()
	encoder := json.NewEncoder(outFile)

	generated := 0
	for i, prompt := range remaining {
		// Once the command is cancelled every further generation would fail,
		// so stop here; completed prompts are already in the state file.
		if err := ctx.Err(); err != nil {
			slog.Warn("lesson: interrupted", "id", lesson.ID, "error", err)
			break
		}

		promptStart := time.Now()

		slog.Info("lesson: generating",
			"prompt", fmt.Sprintf("%d/%d", i+1, len(remaining)),
			"id", prompt.ID,
			"category", prompt.Category,
		)

		// Build messages
		var messages []ml.Message
		if lesson.System != "" {
			messages = append(messages, ml.Message{Role: "system", Content: lesson.System})
		}

		userContent := prompt.Prompt
		if sandwich {
			userContent = buildSandwich(kbText, prompt.Prompt, kernelText)
		}
		messages = append(messages, ml.Message{Role: "user", Content: userContent})

		// Generate
		response, err := backend.Chat(ctx, messages, opts)
		if err != nil {
			// Leave the prompt pending so a later resume retries it.
			slog.Error("lesson: generation failed",
				"id", prompt.ID,
				"error", err,
			)
			continue
		}

		elapsed := time.Since(promptStart)

		// Write training record. Only the user and assistant turns are
		// stored; the system message is used for generation but not recorded.
		record := struct {
			Messages []ml.Message `json:"messages"`
		}{
			Messages: []ml.Message{
				{Role: "user", Content: userContent},
				{Role: "assistant", Content: response},
			},
		}
		if err := encoder.Encode(record); err != nil {
			return fmt.Errorf("write record: %w", err)
		}

		// Update state
		state.Completed[prompt.ID] = lessonResult{
			ResponseChars: len(response),
			Duration:      elapsed.Round(time.Second).String(),
			CompletedAt:   time.Now().Format(time.RFC3339),
		}
		state.UpdatedAt = time.Now().Format(time.RFC3339)

		if err := saveLessonState(stateFile, state); err != nil {
			// The record itself is already written, so keep going.
			slog.Warn("lesson: failed to save state", "error", err)
		}

		generated++

		slog.Info("lesson: generated",
			"id", prompt.ID,
			"category", prompt.Category,
			"response_chars", len(response),
			"duration", elapsed.Round(time.Second),
		)

		// Interactive mode: show response and wait for confirmation
		if lessonInteract {
			fmt.Printf("\n--- %s (%s) ---\n", prompt.ID, prompt.Category)
			fmt.Printf("Prompt: %s\n\n", prompt.Prompt)
			if prompt.Signal != "" {
				fmt.Printf("Signal: %s\n\n", prompt.Signal)
			}
			fmt.Printf("Response:\n%s\n", response)
			fmt.Printf("\nPress Enter to continue (or 'q' to stop)... ")
			var input string
			// The error is ignored on purpose: a bare Enter makes Scanln
			// report an error and is the normal "continue" answer.
			fmt.Scanln(&input)
			if strings.TrimSpace(input) == "q" {
				break
			}
		}

		// Periodic cleanup
		if (i+1)%4 == 0 {
			runtime.GC()
		}
	}

	slog.Info("lesson: complete",
		"id", lesson.ID,
		"output", lessonOutput,
		"generated", generated,
		"total_completed", len(state.Completed),
		"total_prompts", len(lesson.Prompts),
		"duration", time.Since(start).Round(time.Second),
	)

	return nil
}
// loadLessonState reads the lesson state stored as JSON at path.
//
// It returns the zero lessonState when the file does not exist, cannot be
// read, or does not decode cleanly; the last two cases are logged so that a
// damaged state file does not go unnoticed. A state that decodes successfully
// always has a non-nil Completed map.
func loadLessonState(path string) lessonState {
	data, err := os.ReadFile(path)
	if err != nil {
		// A missing file is the normal first-run case and is not reported.
		if !os.IsNotExist(err) {
			slog.Warn("lesson: failed to read state, starting fresh", "path", path, "error", err)
		}
		return lessonState{}
	}

	var state lessonState
	if err := json.Unmarshal(data, &state); err != nil {
		// Discard whatever was partially decoded rather than resuming from it.
		slog.Warn("lesson: ignoring invalid state file", "path", path, "error", err)
		return lessonState{}
	}
	if state.Completed == nil {
		// `"completed": null` or a missing key would otherwise yield a map
		// that panics on the first write.
		state.Completed = make(map[string]lessonResult)
	}
	return state
}
// saveLessonState writes state as indented JSON to path.
//
// The data is first written to "<path>.tmp" and then renamed over path, so an
// interruption during the write cannot leave a truncated state file behind.
// It returns the marshalling, write or rename error, if any.
func saveLessonState(path string, state lessonState) error {
	data, err := json.MarshalIndent(state, "", " ")
	if err != nil {
		return err
	}

	tmp := path + ".tmp"
	if err := os.WriteFile(tmp, data, 0644); err != nil {
		return err
	}
	if err := os.Rename(tmp, path); err != nil {
		// Best-effort cleanup; the rename error is the one worth reporting.
		os.Remove(tmp)
		return err
	}
	return nil
}

View file

@ -0,0 +1,8 @@
//go:build darwin && arm64
package ml
// init adds the lesson and sequence commands to mlCmd.
func init() {
mlCmd.AddCommand(lessonCmd)
mlCmd.AddCommand(sequenceCmd)
}

326
cmd/ml/cmd_sequence.go Normal file
View file

@ -0,0 +1,326 @@
//go:build darwin && arm64
package ml
import (
"encoding/json"
"fmt"
"log/slog"
"os"
"path/filepath"
"strings"
"time"
"forge.lthn.ai/core/go-ai/ml"
"forge.lthn.ai/core/go/pkg/cli"
"gopkg.in/yaml.v3"
)
// sequenceCmd is the "sequence" subcommand. Its long help text documents the
// sequence YAML format and the two modes; the command itself is executed by
// runSequence.
var sequenceCmd = &cli.Command{
Use: "sequence",
Short: "Run a training sequence of multiple lessons",
Long: `Runs an ordered sequence of lessons defined in a YAML file.
Sequence YAML format:
id: lek-full
title: "LEK Full Training Sequence"
mode: vertical
model-path: /path/to/model
lessons:
- sovereignty.yaml
- privacy.yaml
- censorship.yaml
Mode:
vertical Run lessons strictly in order (default)
horizontal Run all lessons, order doesn't matter
State is tracked per-sequence so runs can be resumed.`,
RunE: runSequence,
}
// Flag values for the sequence command. They are bound to their flags in init
// and read by runSequence.
var (
// sequenceFile is the path of the sequence YAML file (--file).
sequenceFile string
// sequenceModelPath is the model directory (--model-path); when it is empty
// runSequence falls back to the model-path value of the sequence YAML.
sequenceModelPath string
// sequenceOutput is the output JSONL path (--output); runSequence replaces
// an empty value with "<sequence-id>.jsonl".
sequenceOutput string
// sequenceMaxTokens is the per-response token limit (--max-tokens).
sequenceMaxTokens int
// sequenceTemp is the sampling temperature (--temperature).
sequenceTemp float64
// sequenceMemLimit is bound to --memory-limit; runSequence does not read it.
sequenceMemLimit int
)
// init binds the sequence command's flags to the package-level sequence*
// variables and marks --file as required.
//
// --model-path is deliberately not marked required: runSequence accepts the
// model path from either the flag or the sequence YAML and reports an error
// itself when neither provides one. Marking the flag required would reject
// the command before runSequence could consult the YAML.
func init() {
	sequenceCmd.Flags().StringVar(&sequenceFile, "file", "", "Sequence YAML file (required)")
	sequenceCmd.Flags().StringVar(&sequenceModelPath, "model-path", "", "Path to model directory (required unless set in the sequence YAML)")
	sequenceCmd.Flags().StringVar(&sequenceOutput, "output", "", "Output JSONL file (default: <sequence-id>.jsonl)")
	sequenceCmd.Flags().IntVar(&sequenceMaxTokens, "max-tokens", 1024, "Max tokens per response")
	sequenceCmd.Flags().Float64Var(&sequenceTemp, "temperature", 0.4, "Sampling temperature")
	sequenceCmd.Flags().IntVar(&sequenceMemLimit, "memory-limit", 24, "Metal memory limit in GB")
	sequenceCmd.MarkFlagRequired("file")
}
// sequenceDef is a YAML sequence definition.
type sequenceDef struct {
// ID identifies the sequence; it names the default output and state files.
// When empty, the sequence file's base name without extension is used.
ID string `yaml:"id"`
// Title is a human-readable name that is only logged.
Title string `yaml:"title"`
// Mode defaults to "vertical" when empty. In vertical mode a lesson that
// cannot be read or parsed halts the sequence; with any other value such a
// lesson is skipped.
Mode string `yaml:"mode"` // "vertical" (default) or "horizontal"
// ModelPath is used when the --model-path flag is empty.
ModelPath string `yaml:"model-path"`
// Lessons lists lesson files; relative paths are resolved against the
// directory of the sequence file.
Lessons []string `yaml:"lessons"` // Relative paths to lesson YAML files
}
// sequenceState tracks progress through a sequence. It is stored as JSON in
// "<sequence-id>.sequence-state.json".
type sequenceState struct {
// SequenceID is the ID of the sequence this state belongs to.
SequenceID string `json:"sequence_id"`
// Completed records which lessons have been run; such lessons are skipped
// on later runs.
Completed map[string]bool `json:"completed"` // lesson ID → done
// Current is the ID of the lesson most recently started; it is cleared
// when the sequence finishes.
Current string `json:"current"`
// UpdatedAt is the RFC 3339 time of the last update.
UpdatedAt string `json:"updated_at"`
}
// runSequence executes the "sequence" command. It loads the sequence YAML
// named by --file, loads the model once, and then runs every listed lesson:
// each prompt is generated with the MLX backend and the user/assistant pair is
// appended as one JSON line to the shared output file.
//
// Progress is tracked per lesson in "<sequence-id>.sequence-state.json" in the
// working directory; lessons already marked completed there are skipped. A
// lesson is marked completed once all of its prompts have been attempted, even
// if some generations failed (those failures are logged).
//
// In vertical mode a lesson file that cannot be read or parsed halts the
// sequence with an error; in any other mode that lesson is skipped. Unreadable
// sandwich files are logged and the lesson runs without sandwiching.
//
// If the command's context is cancelled the run stops, the lesson in progress
// is NOT marked completed, and an error wrapping the context error is
// returned, so that a later run picks the lesson up again.
func runSequence(cmd *cli.Command, args []string) error {
	start := time.Now()

	ctx := cmd.Context()
	// cancelled reports whether the command's context has been cancelled.
	// The nil check keeps a command run without a context working.
	cancelled := func() bool {
		return ctx != nil && ctx.Err() != nil
	}

	// Load sequence YAML
	data, err := os.ReadFile(sequenceFile)
	if err != nil {
		return fmt.Errorf("read sequence: %w", err)
	}
	var seq sequenceDef
	if err := yaml.Unmarshal(data, &seq); err != nil {
		return fmt.Errorf("parse sequence: %w", err)
	}
	if seq.ID == "" {
		seq.ID = strings.TrimSuffix(filepath.Base(sequenceFile), filepath.Ext(sequenceFile))
	}
	if seq.Mode == "" {
		seq.Mode = "vertical"
	}
	if seq.Mode != "vertical" && seq.Mode != "horizontal" {
		// Anything that is not "vertical" gets the lenient behaviour below;
		// say so instead of silently accepting a typo.
		slog.Warn("sequence: unknown mode, treating as horizontal", "mode", seq.Mode)
	}

	// Model path from sequence or flag
	modelPath := sequenceModelPath
	if modelPath == "" && seq.ModelPath != "" {
		modelPath = seq.ModelPath
	}
	if modelPath == "" {
		return fmt.Errorf("model-path is required (flag or sequence YAML)")
	}

	// Resolve output
	if sequenceOutput == "" {
		sequenceOutput = seq.ID + ".jsonl"
	}

	slog.Info("sequence: loaded",
		"id", seq.ID,
		"title", seq.Title,
		"mode", seq.Mode,
		"lessons", len(seq.Lessons),
	)

	// Load state
	stateFile := seq.ID + ".sequence-state.json"
	state := loadSequenceState(stateFile)
	if state.SequenceID == "" {
		state.SequenceID = seq.ID
		state.Completed = make(map[string]bool)
	} else if state.Completed == nil {
		// A state file with a sequence ID but no "completed" object decodes
		// to a nil map; marking a lesson done in it would panic.
		state.Completed = make(map[string]bool)
	}

	// Load model once for all lessons
	slog.Info("sequence: loading model", "path", modelPath)
	backend, err := ml.NewMLXBackend(modelPath)
	if err != nil {
		return fmt.Errorf("load model: %w", err)
	}

	opts := ml.GenOpts{
		Temperature: sequenceTemp,
		MaxTokens:   sequenceMaxTokens,
	}

	// Open output file
	outFile, err := os.OpenFile(sequenceOutput, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
	if err != nil {
		return fmt.Errorf("create output: %w", err)
	}
	defer outFile.Close()
	encoder := json.NewEncoder(outFile)

	baseDir := filepath.Dir(sequenceFile)
	totalGenerated := 0
	interrupted := false

	for i, lessonPath := range seq.Lessons {
		if cancelled() {
			interrupted = true
			break
		}

		// Resolve lesson path
		if !filepath.IsAbs(lessonPath) {
			lessonPath = filepath.Join(baseDir, lessonPath)
		}

		// Load lesson
		lessonData, err := os.ReadFile(lessonPath)
		if err != nil {
			slog.Error("sequence: failed to read lesson",
				"path", lessonPath,
				"error", err,
			)
			if seq.Mode == "vertical" {
				return fmt.Errorf("vertical sequence halted: %w", err)
			}
			continue
		}

		var lesson lessonDef
		if err := yaml.Unmarshal(lessonData, &lesson); err != nil {
			slog.Error("sequence: failed to parse lesson",
				"path", lessonPath,
				"error", err,
			)
			if seq.Mode == "vertical" {
				return fmt.Errorf("vertical sequence halted: %w", err)
			}
			continue
		}

		if lesson.ID == "" {
			lesson.ID = strings.TrimSuffix(filepath.Base(lessonPath), filepath.Ext(lessonPath))
		}

		// Skip completed lessons
		if state.Completed[lesson.ID] {
			slog.Info("sequence: skipping completed lesson",
				"lesson", fmt.Sprintf("%d/%d", i+1, len(seq.Lessons)),
				"id", lesson.ID,
			)
			continue
		}

		// Persist the lesson being started so an interrupted run leaves a
		// record of where it was.
		state.Current = lesson.ID
		state.UpdatedAt = time.Now().Format(time.RFC3339)
		saveSequenceState(stateFile, state)

		slog.Info("sequence: starting lesson",
			"lesson", fmt.Sprintf("%d/%d", i+1, len(seq.Lessons)),
			"id", lesson.ID,
			"title", lesson.Title,
			"prompts", len(lesson.Prompts),
		)

		// Load sandwich files for this lesson. Relative paths are resolved
		// against the lesson file's directory; read failures are logged and
		// the lesson then runs without sandwiching.
		var kbText, kernelText string
		hasSandwich := false
		if lesson.Sandwich != nil {
			lessonDir := filepath.Dir(lessonPath)
			if lesson.Sandwich.KB != "" {
				kbPath := lesson.Sandwich.KB
				if !filepath.IsAbs(kbPath) {
					kbPath = filepath.Join(lessonDir, kbPath)
				}
				d, err := os.ReadFile(kbPath)
				if err != nil {
					slog.Error("sequence: failed to read KB", "error", err)
				} else {
					kbText = string(d)
				}
			}
			if lesson.Sandwich.Kernel != "" {
				kernelPath := lesson.Sandwich.Kernel
				if !filepath.IsAbs(kernelPath) {
					kernelPath = filepath.Join(lessonDir, kernelPath)
				}
				d, err := os.ReadFile(kernelPath)
				if err != nil {
					slog.Error("sequence: failed to read kernel", "error", err)
				} else {
					kernelText = string(d)
				}
			}
			hasSandwich = kbText != "" && kernelText != ""
		}

		// Run each prompt in the lesson
		generated := 0
		for j, prompt := range lesson.Prompts {
			if cancelled() {
				interrupted = true
				break
			}

			var messages []ml.Message
			if lesson.System != "" {
				messages = append(messages, ml.Message{Role: "system", Content: lesson.System})
			}

			userContent := prompt.Prompt
			if hasSandwich {
				userContent = buildSandwich(kbText, prompt.Prompt, kernelText)
			}
			messages = append(messages, ml.Message{Role: "user", Content: userContent})

			slog.Info("sequence: generating",
				"lesson", lesson.ID,
				"prompt", fmt.Sprintf("%d/%d", j+1, len(lesson.Prompts)),
				"id", prompt.ID,
			)

			response, err := backend.Chat(ctx, messages, opts)
			if err != nil {
				if cancelled() {
					// The failure is the cancellation itself; do not count
					// the lesson as attempted.
					interrupted = true
					break
				}
				slog.Error("sequence: generation failed",
					"lesson", lesson.ID,
					"prompt", prompt.ID,
					"error", err,
				)
				continue
			}

			// Only the user and assistant turns are stored; the system
			// message is used for generation but not recorded.
			record := struct {
				Messages []ml.Message `json:"messages"`
			}{
				Messages: []ml.Message{
					{Role: "user", Content: userContent},
					{Role: "assistant", Content: response},
				},
			}
			if err := encoder.Encode(record); err != nil {
				return fmt.Errorf("write record: %w", err)
			}

			generated++
			totalGenerated++
		}

		if interrupted {
			// Leave the lesson unmarked so the next run repeats it.
			break
		}

		// Mark lesson complete
		state.Completed[lesson.ID] = true
		state.UpdatedAt = time.Now().Format(time.RFC3339)
		saveSequenceState(stateFile, state)

		slog.Info("sequence: lesson complete",
			"id", lesson.ID,
			"generated", generated,
			"total", len(lesson.Prompts),
		)
	}

	if interrupted {
		state.UpdatedAt = time.Now().Format(time.RFC3339)
		saveSequenceState(stateFile, state)
		slog.Warn("sequence: interrupted",
			"id", seq.ID,
			"total_generated", totalGenerated,
			"lessons_completed", len(state.Completed),
		)
		return fmt.Errorf("sequence interrupted: %w", ctx.Err())
	}

	state.Current = ""
	state.UpdatedAt = time.Now().Format(time.RFC3339)
	saveSequenceState(stateFile, state)

	slog.Info("sequence: complete",
		"id", seq.ID,
		"output", sequenceOutput,
		"total_generated", totalGenerated,
		"lessons_completed", len(state.Completed),
		"duration", time.Since(start).Round(time.Second),
	)

	return nil
}
// loadSequenceState reads the sequence state stored as JSON at path.
//
// It returns the zero sequenceState when the file does not exist, cannot be
// read, or does not decode cleanly; the last two cases are logged so that a
// damaged state file does not go unnoticed. A state that decodes successfully
// always has a non-nil Completed map.
func loadSequenceState(path string) sequenceState {
	data, err := os.ReadFile(path)
	if err != nil {
		// A missing file is the normal first-run case and is not reported.
		if !os.IsNotExist(err) {
			slog.Warn("sequence: failed to read state, starting fresh", "path", path, "error", err)
		}
		return sequenceState{}
	}

	var state sequenceState
	if err := json.Unmarshal(data, &state); err != nil {
		// Discard whatever was partially decoded rather than resuming from it.
		slog.Warn("sequence: ignoring invalid state file", "path", path, "error", err)
		return sequenceState{}
	}
	if state.Completed == nil {
		// `"completed": null` or a missing key would otherwise yield a map
		// that panics on the first write.
		state.Completed = make(map[string]bool)
	}
	return state
}
// saveSequenceState writes state as indented JSON to path.
//
// Saving is best-effort: the function returns nothing, but every failure is
// logged instead of being dropped silently. The data is first written to
// "<path>.tmp" and then renamed over path, so an interruption during the write
// cannot leave a truncated state file behind.
func saveSequenceState(path string, state sequenceState) {
	data, err := json.MarshalIndent(state, "", " ")
	if err != nil {
		slog.Warn("sequence: failed to encode state", "path", path, "error", err)
		return
	}

	tmp := path + ".tmp"
	if err := os.WriteFile(tmp, data, 0644); err != nil {
		slog.Warn("sequence: failed to save state", "path", path, "error", err)
		return
	}
	if err := os.Rename(tmp, path); err != nil {
		// Best-effort cleanup; the rename error is the one worth reporting.
		os.Remove(tmp)
		slog.Warn("sequence: failed to save state", "path", path, "error", err)
	}
}