cli/cmd/ml/cmd_sequence.go

327 lines
8.1 KiB
Go
Raw Normal View History

//go:build darwin && arm64
package ml
import (
"encoding/json"
"fmt"
"log/slog"
"os"
"path/filepath"
"strings"
"time"
"forge.lthn.ai/core/go-ai/ml"
"forge.lthn.ai/core/go/pkg/cli"
"gopkg.in/yaml.v3"
)
// sequenceCmd is the "sequence" subcommand. Its Long help text documents the
// sequence YAML layout (id, title, mode, model-path, lessons) and the two run
// modes; execution is delegated to runSequence.
var sequenceCmd = &cli.Command{
Use: "sequence",
Short: "Run a training sequence of multiple lessons",
Long: `Runs an ordered sequence of lessons defined in a YAML file.
Sequence YAML format:
id: lek-full
title: "LEK Full Training Sequence"
mode: vertical
model-path: /path/to/model
lessons:
- sovereignty.yaml
- privacy.yaml
- censorship.yaml
Mode:
vertical Run lessons strictly in order (default)
horizontal Run all lessons, order doesn't matter
State is tracked per-sequence so runs can be resumed.`,
// runSequence reads the package-level flag variables bound in init.
RunE: runSequence,
}
// Flag values for sequenceCmd. init binds each variable to a command-line
// flag; runSequence reads them when the command executes.
var (
sequenceFile string // --file: path to the sequence YAML file
sequenceModelPath string // --model-path: model directory; takes precedence over the YAML model-path
sequenceOutput string // --output: JSONL output path; runSequence fills in "<sequence-id>.jsonl" when empty
sequenceMaxTokens int // --max-tokens: copied into ml.GenOpts.MaxTokens
sequenceTemp float64 // --temperature: copied into ml.GenOpts.Temperature
sequenceMemLimit int // --memory-limit (GB). NOTE(review): not read by runSequence in this file — confirm where it is applied
)
// init registers the flags of the sequence command.
//
// Only --file is marked required. --model-path is deliberately left optional
// at the flag level: runSequence falls back to the model-path field of the
// sequence YAML and returns its own error when neither source provides a
// value. Marking the flag required would reject the command before
// runSequence runs and make the YAML field unusable.
func init() {
	flags := sequenceCmd.Flags()
	flags.StringVar(&sequenceFile, "file", "", "Sequence YAML file (required)")
	flags.StringVar(&sequenceModelPath, "model-path", "", "Path to model directory (overrides model-path in the sequence YAML)")
	flags.StringVar(&sequenceOutput, "output", "", "Output JSONL file (default: <sequence-id>.jsonl)")
	flags.IntVar(&sequenceMaxTokens, "max-tokens", 1024, "Max tokens per response")
	flags.Float64Var(&sequenceTemp, "temperature", 0.4, "Sampling temperature")
	flags.IntVar(&sequenceMemLimit, "memory-limit", 24, "Metal memory limit in GB")
	sequenceCmd.MarkFlagRequired("file")
}
// sequenceDef is a YAML sequence definition: an ordered list of lesson files
// plus the metadata runSequence needs to execute them.
type sequenceDef struct {
// ID names the sequence; runSequence derives the state file name and the
// default output name from it. When empty, runSequence substitutes the
// sequence file's base name without its extension.
ID string `yaml:"id"`
// Title is a human-readable label; it is only logged.
Title string `yaml:"title"`
Mode string `yaml:"mode"` // "vertical" (default) or "horizontal"
// ModelPath is used only when the --model-path flag is empty.
ModelPath string `yaml:"model-path"`
Lessons []string `yaml:"lessons"` // Relative paths to lesson YAML files
}
// sequenceState tracks progress through a sequence. runSequence persists it
// as JSON in "<sequence-id>.sequence-state.json" so that lessons already
// marked complete are skipped on a later run.
type sequenceState struct {
// SequenceID is the ID of the sequence this state belongs to.
SequenceID string `json:"sequence_id"`
Completed map[string]bool `json:"completed"` // lesson ID → done
// Current is the ID of the lesson being processed; runSequence clears it
// once the whole sequence has finished.
Current string `json:"current"`
// UpdatedAt is the time of the last save, formatted as RFC 3339.
UpdatedAt string `json:"updated_at"`
}
// runSequence executes the "sequence" command.
//
// It loads the sequence YAML named by --file, loads the model once, and then
// walks the listed lessons in order. For every prompt of a lesson it asks the
// backend for a response and appends one JSON record (user + assistant
// message) to the output file. Progress is persisted after each lesson in
// "<sequence-id>.sequence-state.json", so lessons already marked complete are
// skipped on a later run.
//
// Error handling:
//   - A lesson file that cannot be read or parsed aborts the run in "vertical"
//     mode and is skipped in any other mode.
//   - A failed generation is logged and the prompt is skipped; the lesson is
//     still marked complete afterwards.
//   - If the command's context has been cancelled, the run stops immediately
//     and the current lesson is NOT marked complete, so a resumed run retries
//     it instead of silently skipping every remaining lesson.
//
// args is unused. The returned error is nil when the whole sequence ran.
func runSequence(cmd *cli.Command, args []string) error {
	start := time.Now()

	// Load sequence YAML
	data, err := os.ReadFile(sequenceFile)
	if err != nil {
		return fmt.Errorf("read sequence: %w", err)
	}
	var seq sequenceDef
	if err := yaml.Unmarshal(data, &seq); err != nil {
		return fmt.Errorf("parse sequence: %w", err)
	}
	if seq.ID == "" {
		seq.ID = strings.TrimSuffix(filepath.Base(sequenceFile), filepath.Ext(sequenceFile))
	}
	if seq.Mode == "" {
		seq.Mode = "vertical"
	}

	// Model path: the flag wins, the sequence YAML is the fallback.
	modelPath := sequenceModelPath
	if modelPath == "" {
		modelPath = seq.ModelPath
	}
	if modelPath == "" {
		return fmt.Errorf("model-path is required (flag or sequence YAML)")
	}

	// Resolve output
	if sequenceOutput == "" {
		sequenceOutput = seq.ID + ".jsonl"
	}

	slog.Info("sequence: loaded",
		"id", seq.ID,
		"title", seq.Title,
		"mode", seq.Mode,
		"lessons", len(seq.Lessons),
	)

	// Load state
	stateFile := seq.ID + ".sequence-state.json"
	state := loadSequenceState(stateFile)
	if state.SequenceID == "" {
		state.SequenceID = seq.ID
		state.Completed = make(map[string]bool)
	}
	// A state file may carry a sequence ID but no (or a null) "completed"
	// object; writing to a nil map would panic when a lesson finishes.
	if state.Completed == nil {
		state.Completed = make(map[string]bool)
	}

	// Load model once for all lessons
	slog.Info("sequence: loading model", "path", modelPath)
	backend, err := ml.NewMLXBackend(modelPath)
	if err != nil {
		return fmt.Errorf("load model: %w", err)
	}
	opts := ml.GenOpts{
		Temperature: sequenceTemp,
		MaxTokens:   sequenceMaxTokens,
	}

	// Open output file (append, so resumed runs extend earlier output)
	outFile, err := os.OpenFile(sequenceOutput, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
	if err != nil {
		return fmt.Errorf("create output: %w", err)
	}
	defer func() {
		// A failed close can mean buffered records never reached disk.
		if cerr := outFile.Close(); cerr != nil {
			slog.Error("sequence: failed to close output", "path", sequenceOutput, "error", cerr)
		}
	}()
	encoder := json.NewEncoder(outFile)

	// readOptional reads a sandwich component relative to dir. An empty path
	// or a read failure yields "", which disables the sandwich for the lesson.
	readOptional := func(dir, path, failMsg string) string {
		if path == "" {
			return ""
		}
		if !filepath.IsAbs(path) {
			path = filepath.Join(dir, path)
		}
		content, err := os.ReadFile(path)
		if err != nil {
			slog.Error(failMsg, "error", err)
			return ""
		}
		return string(content)
	}

	ctx := cmd.Context()
	baseDir := filepath.Dir(sequenceFile)
	totalGenerated := 0

	for i, lessonPath := range seq.Lessons {
		// Resolve lesson path relative to the sequence file
		if !filepath.IsAbs(lessonPath) {
			lessonPath = filepath.Join(baseDir, lessonPath)
		}

		// Load lesson
		lessonData, err := os.ReadFile(lessonPath)
		if err != nil {
			slog.Error("sequence: failed to read lesson",
				"path", lessonPath,
				"error", err,
			)
			if seq.Mode == "vertical" {
				return fmt.Errorf("vertical sequence halted: %w", err)
			}
			continue
		}
		var lesson lessonDef
		if err := yaml.Unmarshal(lessonData, &lesson); err != nil {
			slog.Error("sequence: failed to parse lesson",
				"path", lessonPath,
				"error", err,
			)
			if seq.Mode == "vertical" {
				return fmt.Errorf("vertical sequence halted: %w", err)
			}
			continue
		}
		if lesson.ID == "" {
			lesson.ID = strings.TrimSuffix(filepath.Base(lessonPath), filepath.Ext(lessonPath))
		}

		// Skip completed lessons
		if state.Completed[lesson.ID] {
			slog.Info("sequence: skipping completed lesson",
				"lesson", fmt.Sprintf("%d/%d", i+1, len(seq.Lessons)),
				"id", lesson.ID,
			)
			continue
		}

		state.Current = lesson.ID
		slog.Info("sequence: starting lesson",
			"lesson", fmt.Sprintf("%d/%d", i+1, len(seq.Lessons)),
			"id", lesson.ID,
			"title", lesson.Title,
			"prompts", len(lesson.Prompts),
		)

		// Load sandwich files for this lesson; the sandwich is applied only
		// when both the KB and the kernel text were read successfully.
		var kbText, kernelText string
		hasSandwich := false
		if lesson.Sandwich != nil {
			lessonDir := filepath.Dir(lessonPath)
			kbText = readOptional(lessonDir, lesson.Sandwich.KB, "sequence: failed to read KB")
			kernelText = readOptional(lessonDir, lesson.Sandwich.Kernel, "sequence: failed to read kernel")
			hasSandwich = kbText != "" && kernelText != ""
		}

		// Run each prompt in the lesson
		generated := 0
		for j, prompt := range lesson.Prompts {
			var messages []ml.Message
			if lesson.System != "" {
				messages = append(messages, ml.Message{Role: "system", Content: lesson.System})
			}
			userContent := prompt.Prompt
			if hasSandwich {
				userContent = buildSandwich(kbText, prompt.Prompt, kernelText)
			}
			messages = append(messages, ml.Message{Role: "user", Content: userContent})

			slog.Info("sequence: generating",
				"lesson", lesson.ID,
				"prompt", fmt.Sprintf("%d/%d", j+1, len(lesson.Prompts)),
				"id", prompt.ID,
			)

			response, err := backend.Chat(ctx, messages, opts)
			if err != nil {
				slog.Error("sequence: generation failed",
					"lesson", lesson.ID,
					"prompt", prompt.ID,
					"error", err,
				)
				// Once the context is done every further call would fail
				// too; stop here rather than marking the remaining lessons
				// complete without having generated anything.
				if ctx != nil && ctx.Err() != nil {
					return fmt.Errorf("sequence interrupted: %w", ctx.Err())
				}
				continue
			}

			// The record holds only the user/assistant pair; the system
			// message is not written to the output.
			record := struct {
				Messages []ml.Message `json:"messages"`
			}{
				Messages: []ml.Message{
					{Role: "user", Content: userContent},
					{Role: "assistant", Content: response},
				},
			}
			if err := encoder.Encode(record); err != nil {
				return fmt.Errorf("write record: %w", err)
			}
			generated++
			totalGenerated++
		}

		// Mark lesson complete
		state.Completed[lesson.ID] = true
		state.UpdatedAt = time.Now().Format(time.RFC3339)
		saveSequenceState(stateFile, state)

		slog.Info("sequence: lesson complete",
			"id", lesson.ID,
			"generated", generated,
			"total", len(lesson.Prompts),
		)
	}

	state.Current = ""
	state.UpdatedAt = time.Now().Format(time.RFC3339)
	saveSequenceState(stateFile, state)

	slog.Info("sequence: complete",
		"id", seq.ID,
		"output", sequenceOutput,
		"total_generated", totalGenerated,
		"lessons_completed", len(state.Completed),
		"duration", time.Since(start).Round(time.Second),
	)
	return nil
}
// loadSequenceState reads the persisted progress of a sequence from path.
//
// It returns the zero sequenceState when the file does not exist, cannot be
// read, or does not contain valid state JSON, so the caller starts a fresh
// run. A missing file is the normal first-run case and is silent; any other
// failure is logged so that lost progress does not go unnoticed.
func loadSequenceState(path string) sequenceState {
	data, err := os.ReadFile(path)
	if err != nil {
		if !os.IsNotExist(err) {
			slog.Warn("sequence: failed to read state, starting fresh",
				"path", path,
				"error", err,
			)
		}
		return sequenceState{}
	}

	var state sequenceState
	if err := json.Unmarshal(data, &state); err != nil {
		// Discard anything that was partially decoded: a half-populated
		// state is worse than none.
		slog.Warn("sequence: failed to parse state, starting fresh",
			"path", path,
			"error", err,
		)
		return sequenceState{}
	}
	return state
}
// saveSequenceState writes state to path as indented JSON (mode 0644).
//
// Saving is best-effort: a failure does not abort the caller, but it is
// logged, because a state file that was not written means completed lessons
// will be repeated on the next run.
func saveSequenceState(path string, state sequenceState) {
	data, err := json.MarshalIndent(state, "", " ")
	if err != nil {
		slog.Error("sequence: failed to encode state", "path", path, "error", err)
		return
	}
	if err := os.WriteFile(path, data, 0644); err != nil {
		slog.Error("sequence: failed to write state", "path", path, "error", err)
	}
}