cli/cmd/ml/cmd_sequence.go
Snider f1f0498716 feat(ml): add lesson and sequence commands for structured training
Lesson command runs prompts from YAML definitions with state tracking,
sandwich signing, and interactive review mode. Sequence command runs
multiple lessons in order (vertical/strict or horizontal/flexible).

State files enable resume after interruption. Both output chat JSONL
compatible with 'core ml train'.

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-17 17:52:02 +00:00

326 lines
8.1 KiB
Go

//go:build darwin && arm64
package ml
import (
"encoding/json"
"fmt"
"log/slog"
"os"
"path/filepath"
"strings"
"time"
"forge.lthn.ai/core/go-ai/ml"
"forge.lthn.ai/core/go/pkg/cli"
"gopkg.in/yaml.v3"
)
// sequenceCmd defines the `sequence` subcommand, which runs an ordered list of
// lessons taken from a sequence YAML file. The work is done by runSequence;
// the command's flags are registered in init.
var sequenceCmd = &cli.Command{
	Use:   "sequence",
	Short: "Run a training sequence of multiple lessons",
	Long: `Runs an ordered sequence of lessons defined in a YAML file.
Sequence YAML format:
id: lek-full
title: "LEK Full Training Sequence"
mode: vertical
model-path: /path/to/model
lessons:
- sovereignty.yaml
- privacy.yaml
- censorship.yaml
Mode:
vertical Run lessons strictly in order (default)
horizontal Run all lessons, order doesn't matter
State is tracked per-sequence so runs can be resumed.`,
	RunE: runSequence,
}
// Settings for the sequence command. Each one is bound to a command-line flag
// in init and read by runSequence.
var (
	// Bound to --file, --model-path and --output respectively.
	sequenceFile, sequenceModelPath, sequenceOutput string
	// Bound to --max-tokens and --memory-limit (the latter in GB).
	sequenceMaxTokens, sequenceMemLimit int
	// Bound to --temperature.
	sequenceTemp float64
)
// init registers the sequence command's flags.
//
// Only --file is marked required. The model path may come from --model-path
// or from the sequence YAML's model-path key (the flag wins when both are
// set). Marking --model-path required would make that YAML fallback in
// runSequence unreachable and reject sequence files that name their own model.
func init() {
	flags := sequenceCmd.Flags()
	flags.StringVar(&sequenceFile, "file", "", "Sequence YAML file (required)")
	flags.StringVar(&sequenceModelPath, "model-path", "", "Path to model directory (overrides model-path in the sequence YAML)")
	flags.StringVar(&sequenceOutput, "output", "", "Output JSONL file (default: <sequence-id>.jsonl)")
	flags.IntVar(&sequenceMaxTokens, "max-tokens", 1024, "Max tokens per response")
	flags.Float64Var(&sequenceTemp, "temperature", 0.4, "Sampling temperature")
	flags.IntVar(&sequenceMemLimit, "memory-limit", 24, "Metal memory limit in GB")
	sequenceCmd.MarkFlagRequired("file")
}
// sequenceDef is a YAML sequence definition, decoded by runSequence from the
// file named by --file.
type sequenceDef struct {
	// ID names the sequence. It prefixes the default output file and the
	// state file. When empty, runSequence uses the sequence file's base name
	// without its extension.
	ID string `yaml:"id"`
	// Title is informational; runSequence only logs it.
	Title string `yaml:"title"`
	// Mode is "vertical" (default) or "horizontal". In vertical mode a lesson
	// that cannot be read or parsed halts the run with an error.
	Mode string `yaml:"mode"` // "vertical" (default) or "horizontal"
	// ModelPath is used only when --model-path is empty.
	ModelPath string `yaml:"model-path"`
	// Lessons are run in the listed order. Relative paths are resolved
	// against the directory containing the sequence file.
	Lessons []string `yaml:"lessons"` // Relative paths to lesson YAML files
}
// sequenceState tracks progress through a sequence. It is saved as JSON to
// "<sequence-id>.sequence-state.json" after each lesson, so a rerun can skip
// lessons that are already recorded as completed.
type sequenceState struct {
	SequenceID string          `json:"sequence_id"`
	Completed  map[string]bool `json:"completed"`  // lesson ID → done
	Current    string          `json:"current"`    // lesson in progress; cleared when the sequence finishes
	UpdatedAt  string          `json:"updated_at"` // RFC 3339 time of the last update
}
// runSequence implements the sequence command.
//
// It loads the sequence YAML named by --file and loads the model once. It then
// runs every listed lesson's prompts in order. Each successful generation
// appends one chat-format JSONL record ({"messages": [user, assistant]}) to the
// output file (--output, default "<sequence-id>.jsonl", opened in append mode).
//
// Progress is kept in "<sequence-id>.sequence-state.json" in the working
// directory and saved after every lesson, so a rerun skips lessons already
// marked complete. In vertical mode, a lesson that cannot be loaded halts the
// run with an error; in horizontal mode it is logged and skipped. If the
// command context is cancelled, the run stops without marking the in-progress
// lesson complete. A rerun therefore repeats that lesson, and records it had
// already written remain in the append-only output.
func runSequence(cmd *cli.Command, args []string) error {
	start := time.Now()
	ctx := cmd.Context()

	// Load the sequence definition.
	data, err := os.ReadFile(sequenceFile)
	if err != nil {
		return fmt.Errorf("read sequence: %w", err)
	}
	var seq sequenceDef
	if err := yaml.Unmarshal(data, &seq); err != nil {
		return fmt.Errorf("parse sequence: %w", err)
	}
	if seq.ID == "" {
		seq.ID = strings.TrimSuffix(filepath.Base(sequenceFile), filepath.Ext(sequenceFile))
	}

	// Reject unknown modes. Otherwise anything but "vertical" (e.g. a typo)
	// would silently run with horizontal, skip-on-error semantics.
	seq.Mode = strings.ToLower(strings.TrimSpace(seq.Mode))
	switch seq.Mode {
	case "":
		seq.Mode = "vertical"
	case "vertical", "horizontal":
	default:
		return fmt.Errorf("unknown sequence mode %q (want vertical or horizontal)", seq.Mode)
	}
	strict := seq.Mode == "vertical"

	// Fail before the expensive model load when there is nothing to run.
	if len(seq.Lessons) == 0 {
		return fmt.Errorf("sequence %q lists no lessons", seq.ID)
	}

	// --model-path takes precedence over the sequence's model-path.
	modelPath := sequenceModelPath
	if modelPath == "" {
		modelPath = seq.ModelPath
	}
	if modelPath == "" {
		return fmt.Errorf("model-path is required (flag or sequence YAML)")
	}

	// Resolve into a local so the package-level flag value is left untouched.
	outputPath := sequenceOutput
	if outputPath == "" {
		outputPath = seq.ID + ".jsonl"
	}

	slog.Info("sequence: loaded",
		"id", seq.ID,
		"title", seq.Title,
		"mode", seq.Mode,
		"lessons", len(seq.Lessons),
	)

	// Load (or start) resumable progress.
	stateFile := seq.ID + ".sequence-state.json"
	state := loadSequenceState(stateFile)
	if state.SequenceID == "" {
		state = sequenceState{SequenceID: seq.ID}
	}
	// A state file without "completed" decodes to a nil map, and writing to a
	// nil map panics.
	if state.Completed == nil {
		state.Completed = make(map[string]bool)
	}
	saveProgress := func() {
		state.UpdatedAt = time.Now().Format(time.RFC3339)
		saveSequenceState(stateFile, state)
	}
	// interrupted returns a non-nil error once the command context has been
	// cancelled, saving progress first so a rerun resumes correctly.
	interrupted := func() error {
		if ctx == nil || ctx.Err() == nil {
			return nil
		}
		saveProgress()
		return fmt.Errorf("sequence interrupted: %w", ctx.Err())
	}

	// Load the model once for all lessons.
	// NOTE(review): --memory-limit (sequenceMemLimit) is never applied by this
	// command; confirm whether the MLX backend should be configured with it.
	slog.Info("sequence: loading model", "path", modelPath)
	backend, err := ml.NewMLXBackend(modelPath)
	if err != nil {
		return fmt.Errorf("load model: %w", err)
	}
	opts := ml.GenOpts{
		Temperature: sequenceTemp,
		MaxTokens:   sequenceMaxTokens,
	}

	outFile, err := os.OpenFile(outputPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
	if err != nil {
		return fmt.Errorf("create output: %w", err)
	}
	// Safety net for early returns. The success path closes explicitly below
	// so that a failing close is reported instead of dropped.
	defer outFile.Close()
	encoder := json.NewEncoder(outFile)

	baseDir := filepath.Dir(sequenceFile)
	totalGenerated := 0
	for i, lessonPath := range seq.Lessons {
		if err := interrupted(); err != nil {
			return err
		}
		lessonPath = resolveSequencePath(baseDir, lessonPath)
		position := fmt.Sprintf("%d/%d", i+1, len(seq.Lessons))

		lesson, err := readSequenceLesson(lessonPath)
		if err != nil {
			slog.Error("sequence: failed to load lesson",
				"path", lessonPath,
				"error", err,
			)
			if strict {
				return fmt.Errorf("vertical sequence halted: %w", err)
			}
			continue
		}

		if state.Completed[lesson.ID] {
			slog.Info("sequence: skipping completed lesson",
				"lesson", position,
				"id", lesson.ID,
			)
			continue
		}

		state.Current = lesson.ID
		slog.Info("sequence: starting lesson",
			"lesson", position,
			"id", lesson.ID,
			"title", lesson.Title,
			"prompts", len(lesson.Prompts),
		)

		// The sandwich wrapper is applied only when both halves loaded.
		kbText, kernelText := readSequenceSandwich(lesson, lessonPath)
		hasSandwich := kbText != "" && kernelText != ""
		if lesson.Sandwich != nil && !hasSandwich {
			slog.Warn("sequence: sandwich incomplete, prompts will be sent unwrapped",
				"id", lesson.ID,
			)
		}

		generated, failed := 0, 0
		for j, prompt := range lesson.Prompts {
			if err := interrupted(); err != nil {
				return err
			}

			var messages []ml.Message
			if lesson.System != "" {
				messages = append(messages, ml.Message{Role: "system", Content: lesson.System})
			}
			userContent := prompt.Prompt
			if hasSandwich {
				userContent = buildSandwich(kbText, prompt.Prompt, kernelText)
			}
			messages = append(messages, ml.Message{Role: "user", Content: userContent})

			slog.Info("sequence: generating",
				"lesson", lesson.ID,
				"prompt", fmt.Sprintf("%d/%d", j+1, len(lesson.Prompts)),
				"id", prompt.ID,
			)
			response, err := backend.Chat(ctx, messages, opts)
			if err != nil {
				// A cancelled context fails this and every later call. Stop
				// rather than marking the lesson complete with prompts unanswered.
				if ierr := interrupted(); ierr != nil {
					return ierr
				}
				slog.Error("sequence: generation failed",
					"lesson", lesson.ID,
					"prompt", prompt.ID,
					"error", err,
				)
				failed++
				continue
			}

			// NOTE(review): the system message is sent to the model but not
			// written to the record; confirm this matches the lesson command.
			record := struct {
				Messages []ml.Message `json:"messages"`
			}{
				Messages: []ml.Message{
					{Role: "user", Content: userContent},
					{Role: "assistant", Content: response},
				},
			}
			if err := encoder.Encode(record); err != nil {
				return fmt.Errorf("write record: %w", err)
			}
			generated++
			totalGenerated++
		}

		state.Completed[lesson.ID] = true
		saveProgress()
		if failed > 0 {
			slog.Warn("sequence: lesson completed with failed prompts; they will not be retried on resume",
				"id", lesson.ID,
				"failed", failed,
			)
		}
		slog.Info("sequence: lesson complete",
			"id", lesson.ID,
			"generated", generated,
			"total", len(lesson.Prompts),
		)
	}

	if err := outFile.Close(); err != nil {
		return fmt.Errorf("close output: %w", err)
	}

	state.Current = ""
	saveProgress()

	slog.Info("sequence: complete",
		"id", seq.ID,
		"output", outputPath,
		"total_generated", totalGenerated,
		"lessons_completed", len(state.Completed),
		"duration", time.Since(start).Round(time.Second),
	)
	return nil
}

// resolveSequencePath returns p unchanged when it is absolute, otherwise p
// joined onto baseDir.
func resolveSequencePath(baseDir, p string) string {
	if filepath.IsAbs(p) {
		return p
	}
	return filepath.Join(baseDir, p)
}

// readSequenceLesson reads and parses the lesson YAML at path. When the
// lesson has no ID, it defaults to the file's base name without extension.
func readSequenceLesson(path string) (lessonDef, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return lessonDef{}, fmt.Errorf("read lesson %s: %w", path, err)
	}
	var lesson lessonDef
	if err := yaml.Unmarshal(data, &lesson); err != nil {
		return lessonDef{}, fmt.Errorf("parse lesson %s: %w", path, err)
	}
	if lesson.ID == "" {
		lesson.ID = strings.TrimSuffix(filepath.Base(path), filepath.Ext(path))
	}
	return lesson, nil
}

// readSequenceSandwich loads the KB and kernel texts named by the lesson's
// sandwich settings. Relative paths are resolved against the lesson file's
// directory. A missing setting or a read failure (which is logged) yields an
// empty string for that half.
func readSequenceSandwich(lesson lessonDef, lessonPath string) (kbText, kernelText string) {
	if lesson.Sandwich == nil {
		return "", ""
	}
	lessonDir := filepath.Dir(lessonPath)
	read := func(kind, p string) string {
		if p == "" {
			return ""
		}
		data, err := os.ReadFile(resolveSequencePath(lessonDir, p))
		if err != nil {
			slog.Error("sequence: failed to read "+kind, "lesson", lesson.ID, "error", err)
			return ""
		}
		return string(data)
	}
	return read("KB", lesson.Sandwich.KB), read("kernel", lesson.Sandwich.Kernel)
}
// loadSequenceState reads a previously saved sequence state from path.
//
// A missing file yields a fresh state. An unreadable or undecodable file is
// logged and also yields a fresh state (the sequence restarts from its first
// lesson), never a half-decoded one. The returned Completed map is always
// non-nil, so callers can record lessons in it without initialising it.
func loadSequenceState(path string) sequenceState {
	fresh := sequenceState{Completed: make(map[string]bool)}

	data, err := os.ReadFile(path)
	if err != nil {
		// A missing file is the normal first-run case; anything else is worth a warning.
		if !os.IsNotExist(err) {
			slog.Warn("sequence: cannot read state, starting fresh", "path", path, "error", err)
		}
		return fresh
	}

	var state sequenceState
	if err := json.Unmarshal(data, &state); err != nil {
		slog.Warn("sequence: corrupt state, starting fresh", "path", path, "error", err)
		return fresh
	}
	if state.Completed == nil {
		state.Completed = make(map[string]bool)
	}
	return state
}
// saveSequenceState writes state to path as indented JSON.
//
// The data goes to a temporary file in the same directory, which is then
// renamed over path. An interruption mid-write therefore leaves the previous
// state file intact instead of a truncated one. Failures are logged rather
// than returned: a lost checkpoint should not abort a long run, but it must
// not go unnoticed either.
func saveSequenceState(path string, state sequenceState) {
	data, err := json.MarshalIndent(state, "", " ")
	if err != nil {
		slog.Warn("sequence: failed to encode state", "path", path, "error", err)
		return
	}

	tmp, err := os.CreateTemp(filepath.Dir(path), filepath.Base(path)+".*.tmp")
	if err != nil {
		slog.Warn("sequence: failed to save state", "path", path, "error", err)
		return
	}
	tmpName := tmp.Name()

	_, err = tmp.Write(data)
	if err == nil {
		// CreateTemp uses mode 0600; keep the 0644 the state file has always had.
		err = tmp.Chmod(0644)
	}
	if closeErr := tmp.Close(); err == nil {
		err = closeErr
	}
	if err == nil {
		err = os.Rename(tmpName, path)
	}
	if err != nil {
		os.Remove(tmpName) // best effort; the previous state file is untouched
		slog.Warn("sequence: failed to save state", "path", path, "error", err)
	}
}