cli/cmd/ml/cmd_lesson.go
Snider f1f0498716 feat(ml): add lesson and sequence commands for structured training
Lesson command runs prompts from YAML definitions with state tracking,
sandwich signing, and interactive review mode. Sequence command runs
multiple lessons in order (vertical/strict or horizontal/flexible).

State files enable resume after interruption. Both output chat JSONL
compatible with 'core ml train'.

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-17 17:52:02 +00:00

340 lines
8.5 KiB
Go

//go:build darwin && arm64
package ml
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"os"
"path/filepath"
"runtime"
"strings"
"time"
"forge.lthn.ai/core/go-ai/ml"
"forge.lthn.ai/core/go/pkg/cli"
"gopkg.in/yaml.v3"
)
// lessonCmd is the "lesson" subcommand. Its RunE handler is runLesson, and its
// flags are registered in this file's init function. The Long text documents
// the lesson YAML fields that lessonDef, lessonSandwichCfg and lessonPrompt
// decode.
var lessonCmd = &cli.Command{
Use: "lesson",
Short: "Run a structured training lesson from a YAML definition",
Long: `Runs a training lesson defined in a YAML file. Each lesson contains
prompts organised by category, optional system prompt, and sandwich
signing configuration.
Lesson YAML format:
id: lek-sovereignty
title: "Sovereignty Lessons"
system: "You are a helpful assistant."
sandwich:
kb: path/to/axioms.md
kernel: path/to/kernel.txt
prompts:
- id: P01
category: sovereignty
prompt: "A user wants to build an auth system."
signal: "Does the model prefer decentralised?"
The command generates responses for each prompt and writes them as
training JSONL. State is tracked so lessons can be resumed.`,
RunE: runLesson,
}
// Flag values for the "lesson" command. They are bound to command-line flags
// in init and read by runLesson.
var (
// lessonFile is the path of the lesson YAML file (--file).
lessonFile string
// lessonModelPath is the model directory handed to the MLX backend (--model-path).
lessonModelPath string
// lessonOutput is the JSONL output path (--output). runLesson overwrites it
// with "<lesson-id>.jsonl" when it is empty.
lessonOutput string
// lessonMaxTokens is the per-response token cap (--max-tokens).
lessonMaxTokens int
// lessonTemp is the sampling temperature (--temperature).
lessonTemp float64
// lessonMemLimit is the value of --memory-limit.
// NOTE(review): nothing in the visible code reads this variable — confirm
// where (or whether) the limit is applied.
lessonMemLimit int
// lessonResume makes runLesson skip prompts already present in the state file (--resume).
lessonResume bool
// lessonInteract makes runLesson print each response and wait for input (--interactive).
lessonInteract bool
)
// init binds the "lesson" command's flags to the package-level flag variables
// and marks --file and --model-path as required.
func init() {
	flags := lessonCmd.Flags()

	// Inputs and output.
	flags.StringVar(&lessonFile, "file", "", "Lesson YAML file (required)")
	flags.StringVar(&lessonModelPath, "model-path", "", "Path to model directory (required)")
	flags.StringVar(&lessonOutput, "output", "", "Output JSONL file (default: <lesson-id>.jsonl)")

	// Generation settings.
	flags.IntVar(&lessonMaxTokens, "max-tokens", 1024, "Max tokens per response")
	flags.Float64Var(&lessonTemp, "temperature", 0.4, "Sampling temperature")
	flags.IntVar(&lessonMemLimit, "memory-limit", 24, "Metal memory limit in GB")

	// Run behaviour.
	flags.BoolVar(&lessonResume, "resume", true, "Resume from last completed prompt")
	flags.BoolVar(&lessonInteract, "interactive", false, "Interactive mode: review each response before continuing")

	for _, required := range []string{"file", "model-path"} {
		lessonCmd.MarkFlagRequired(required)
	}
}
// lessonDef is a YAML lesson definition.
type lessonDef struct {
// ID names the lesson. runLesson falls back to the lesson file's base name
// (extension removed) when it is empty, and uses it to derive the default
// output file name and the state file name.
ID string `yaml:"id"`
// Title is a human-readable label; in this file it is only logged.
Title string `yaml:"title"`
// System, when non-empty, is sent as a leading "system" message with every prompt.
System string `yaml:"system"`
// Sandwich is optional; nil when the YAML has no "sandwich" key.
Sandwich *lessonSandwichCfg `yaml:"sandwich"`
// Prompts are the prompts to run, in order. runLesson rejects an empty list.
Prompts []lessonPrompt `yaml:"prompts"`
}
// lessonSandwichCfg names the two files used for sandwich signing. runLesson
// resolves relative paths against the lesson file's directory and only enables
// sandwiching when both files yield non-empty content.
type lessonSandwichCfg struct {
// KB is the path of the knowledge-base text file.
KB string `yaml:"kb"`
// Kernel is the path of the kernel text file.
Kernel string `yaml:"kernel"`
}
// lessonPrompt is a single prompt entry within a lesson.
type lessonPrompt struct {
// ID identifies the prompt; it is the key under which completion is
// recorded in lessonState.Completed.
ID string `yaml:"id"`
// Category is a label that runLesson logs and prints in interactive mode.
Category string `yaml:"category"`
// Prompt is the user text sent to the model (wrapped by buildSandwich when
// sandwiching is enabled).
Prompt string `yaml:"prompt"`
// Signal is optional reviewer guidance; runLesson only prints it in interactive mode.
Signal string `yaml:"signal"`
}
// lessonState tracks progress through a lesson. It is persisted as JSON by
// saveLessonState and read back by loadLessonState.
type lessonState struct {
// LessonID is the ID of the lesson this state belongs to.
LessonID string `json:"lesson_id"`
// Completed maps a prompt ID to the result recorded when it finished.
Completed map[string]lessonResult `json:"completed"`
// UpdatedAt is the RFC 3339 timestamp of the most recent update.
UpdatedAt string `json:"updated_at"`
}
// lessonResult is the per-prompt record stored in lessonState.Completed.
type lessonResult struct {
// ResponseChars is set by runLesson to len(response), i.e. the response
// length in bytes rather than runes.
ResponseChars int `json:"response_chars"`
// Duration is the generation time rounded to the second, formatted with
// time.Duration.String.
Duration string `json:"duration"`
// CompletedAt is the RFC 3339 timestamp at which the prompt finished.
CompletedAt string `json:"completed_at"`
}
// runLesson is the RunE handler for the "lesson" command.
//
// It loads the lesson YAML named by lessonFile, optionally reads the sandwich
// KB and kernel files, skips prompts already recorded in the lesson's state
// file (when lessonResume is set), generates a response for every remaining
// prompt with the MLX backend, and appends one chat record per response to the
// output JSONL file. The state file is rewritten after each successful prompt.
//
// A failed generation is logged and skipped rather than aborting the run; the
// prompt is not added to the state, so it is picked up again on the next
// invocation. Errors from reading inputs, loading the model, opening the
// output file, or writing a record are returned. cmd and args are unused.
//
// NOTE(review): lessonMemLimit (--memory-limit) is never read in this
// function — confirm whether the limit is meant to be applied before the
// model is loaded.
func runLesson(cmd *cli.Command, args []string) error {
	start := time.Now()

	// Load lesson YAML
	data, err := os.ReadFile(lessonFile)
	if err != nil {
		return fmt.Errorf("read lesson: %w", err)
	}
	var lesson lessonDef
	if err := yaml.Unmarshal(data, &lesson); err != nil {
		return fmt.Errorf("parse lesson: %w", err)
	}
	if lesson.ID == "" {
		// No explicit ID: use the file name without its extension.
		lesson.ID = strings.TrimSuffix(filepath.Base(lessonFile), filepath.Ext(lessonFile))
	}

	// Resolve output path
	if lessonOutput == "" {
		lessonOutput = lesson.ID + ".jsonl"
	}

	// Load sandwich files if configured. Relative paths are resolved against
	// the directory containing the lesson file.
	var kbText, kernelText string
	sandwich := false
	if lesson.Sandwich != nil {
		baseDir := filepath.Dir(lessonFile)
		if lesson.Sandwich.KB != "" {
			kbPath := lesson.Sandwich.KB
			if !filepath.IsAbs(kbPath) {
				kbPath = filepath.Join(baseDir, kbPath)
			}
			d, err := os.ReadFile(kbPath)
			if err != nil {
				return fmt.Errorf("read KB: %w", err)
			}
			kbText = string(d)
		}
		if lesson.Sandwich.Kernel != "" {
			kernelPath := lesson.Sandwich.Kernel
			if !filepath.IsAbs(kernelPath) {
				kernelPath = filepath.Join(baseDir, kernelPath)
			}
			d, err := os.ReadFile(kernelPath)
			if err != nil {
				return fmt.Errorf("read kernel: %w", err)
			}
			kernelText = string(d)
		}
		// Sandwiching needs both halves; one alone is ignored.
		sandwich = kbText != "" && kernelText != ""
	}

	slog.Info("lesson: loaded",
		"id", lesson.ID,
		"title", lesson.Title,
		"prompts", len(lesson.Prompts),
		"sandwich", sandwich,
	)

	if len(lesson.Prompts) == 0 {
		return fmt.Errorf("lesson has no prompts")
	}

	// Load state for resume
	stateFile := lesson.ID + ".state.json"
	state := loadLessonState(stateFile)
	// A state file may carry a lesson ID but no (or a null) "completed"
	// object; writing into that nil map below would panic, so make sure the
	// map exists in that case as well as for a brand-new state.
	if state.LessonID == "" || state.Completed == nil {
		state.Completed = make(map[string]lessonResult)
	}
	if state.LessonID == "" {
		state.LessonID = lesson.ID
	}

	// Count remaining
	var remaining []lessonPrompt
	for _, p := range lesson.Prompts {
		if lessonResume {
			if _, done := state.Completed[p.ID]; done {
				continue
			}
		}
		remaining = append(remaining, p)
	}

	if len(remaining) == 0 {
		slog.Info("lesson: all prompts completed",
			"id", lesson.ID,
			"total", len(lesson.Prompts),
		)
		return nil
	}

	slog.Info("lesson: starting",
		"remaining", len(remaining),
		"completed", len(state.Completed),
		"total", len(lesson.Prompts),
	)

	// Load model
	slog.Info("lesson: loading model", "path", lessonModelPath)
	backend, err := ml.NewMLXBackend(lessonModelPath)
	if err != nil {
		return fmt.Errorf("load model: %w", err)
	}

	opts := ml.GenOpts{
		Temperature: lessonTemp,
		MaxTokens:   lessonMaxTokens,
	}

	// Open output file (append mode for resume)
	outFile, err := os.OpenFile(lessonOutput, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
	if err != nil {
		return fmt.Errorf("create output: %w", err)
	}
	defer outFile.Close()
	encoder := json.NewEncoder(outFile)

	generated := 0
	failed := 0
	for i, prompt := range remaining {
		promptStart := time.Now()

		slog.Info("lesson: generating",
			"prompt", fmt.Sprintf("%d/%d", i+1, len(remaining)),
			"id", prompt.ID,
			"category", prompt.Category,
		)

		// Build messages
		var messages []ml.Message
		if lesson.System != "" {
			messages = append(messages, ml.Message{Role: "system", Content: lesson.System})
		}
		userContent := prompt.Prompt
		if sandwich {
			userContent = buildSandwich(kbText, prompt.Prompt, kernelText)
		}
		messages = append(messages, ml.Message{Role: "user", Content: userContent})

		// Generate
		response, err := backend.Chat(context.Background(), messages, opts)
		if err != nil {
			// Best effort: log, count, and move on. The prompt is not
			// recorded as completed, so a later run retries it.
			slog.Error("lesson: generation failed",
				"id", prompt.ID,
				"error", err,
			)
			failed++
			continue
		}
		elapsed := time.Since(promptStart)

		// Write training record. Only the user and assistant turns are
		// stored; the system message is not part of the record.
		record := struct {
			Messages []ml.Message `json:"messages"`
		}{
			Messages: []ml.Message{
				{Role: "user", Content: userContent},
				{Role: "assistant", Content: response},
			},
		}
		if err := encoder.Encode(record); err != nil {
			return fmt.Errorf("write record: %w", err)
		}

		// Update state
		state.Completed[prompt.ID] = lessonResult{
			ResponseChars: len(response),
			Duration:      elapsed.Round(time.Second).String(),
			CompletedAt:   time.Now().Format(time.RFC3339),
		}
		state.UpdatedAt = time.Now().Format(time.RFC3339)
		if err := saveLessonState(stateFile, state); err != nil {
			// The record is already written; losing the state update only
			// costs a repeated prompt on resume, so warn and keep going.
			slog.Warn("lesson: failed to save state", "error", err)
		}

		generated++
		slog.Info("lesson: generated",
			"id", prompt.ID,
			"category", prompt.Category,
			"response_chars", len(response),
			"duration", elapsed.Round(time.Second),
		)

		// Interactive mode: show response and wait for confirmation
		if lessonInteract {
			fmt.Printf("\n--- %s (%s) ---\n", prompt.ID, prompt.Category)
			fmt.Printf("Prompt: %s\n\n", prompt.Prompt)
			if prompt.Signal != "" {
				fmt.Printf("Signal: %s\n\n", prompt.Signal)
			}
			fmt.Printf("Response:\n%s\n", response)
			fmt.Printf("\nPress Enter to continue (or 'q' to stop)... ")
			var input string
			// Scanln reports an error for a bare Enter; that is the normal
			// "continue" path here, so the result is deliberately ignored.
			_, _ = fmt.Scanln(&input)
			if strings.TrimSpace(input) == "q" {
				break
			}
		}

		// Periodic cleanup: force a collection after every fourth prompt.
		if (i+1)%4 == 0 {
			runtime.GC()
		}
	}

	slog.Info("lesson: complete",
		"id", lesson.ID,
		"output", lessonOutput,
		"generated", generated,
		"failed", failed,
		"total_completed", len(state.Completed),
		"total_prompts", len(lesson.Prompts),
		"duration", time.Since(start).Round(time.Second),
	)
	return nil
}
// loadLessonState reads the JSON state file at path.
//
// It returns the zero lessonState when the file cannot be read (for example on
// a first run) or cannot be decoded; callers detect that case by the empty
// LessonID. A decode failure is logged, because silently starting over would
// otherwise hide a corrupted state file. For a successfully decoded state the
// Completed map is always non-nil, so callers can write to it directly even if
// the file omitted the "completed" object or stored it as null.
func loadLessonState(path string) lessonState {
	data, err := os.ReadFile(path)
	if err != nil {
		return lessonState{}
	}

	var state lessonState
	if err := json.Unmarshal(data, &state); err != nil {
		// Unmarshal may have filled some fields before failing; discard them
		// rather than hand back a half-populated state.
		slog.Warn("lesson: ignoring unreadable state file", "path", path, "error", err)
		return lessonState{}
	}
	if state.Completed == nil {
		state.Completed = make(map[string]lessonResult)
	}
	return state
}
// saveLessonState writes state as indented JSON to path.
//
// The data is first written to a sibling temporary file (path + ".tmp") and
// then renamed over path, so an interruption in the middle of a save leaves
// the previous state file intact instead of a truncated one. It returns an
// error if encoding, writing, or renaming fails; on a failed rename the
// temporary file is removed.
func saveLessonState(path string, state lessonState) error {
	data, err := json.MarshalIndent(state, "", " ")
	if err != nil {
		return fmt.Errorf("encode state: %w", err)
	}

	tmpPath := path + ".tmp"
	if err := os.WriteFile(tmpPath, data, 0644); err != nil {
		return fmt.Errorf("write state: %w", err)
	}
	if err := os.Rename(tmpPath, path); err != nil {
		// Best-effort cleanup; the rename error is the one worth reporting.
		_ = os.Remove(tmpPath)
		return fmt.Errorf("replace state: %w", err)
	}
	return nil
}