327 lines
8.1 KiB
Go
327 lines
8.1 KiB
Go
|
|
//go:build darwin && arm64
|
||
|
|
|
||
|
|
package ml
|
||
|
|
|
||
|
|
import (
|
||
|
|
"encoding/json"
|
||
|
|
"fmt"
|
||
|
|
"log/slog"
|
||
|
|
"os"
|
||
|
|
"path/filepath"
|
||
|
|
"strings"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"forge.lthn.ai/core/go-ai/ml"
|
||
|
|
"forge.lthn.ai/core/go/pkg/cli"
|
||
|
|
"gopkg.in/yaml.v3"
|
||
|
|
)
|
||
|
|
|
||
|
|
// sequenceCmd is the "sequence" command definition. Its flags are registered
// in init and its work is done by runSequence.
var sequenceCmd = &cli.Command{
	Use:   "sequence",
	Short: "Run a training sequence of multiple lessons",
	// Long is the extended help text: it shows the expected layout of the
	// sequence YAML file and describes the two run modes.
	Long: `Runs an ordered sequence of lessons defined in a YAML file.

Sequence YAML format:
id: lek-full
title: "LEK Full Training Sequence"
mode: vertical
model-path: /path/to/model
lessons:
- sovereignty.yaml
- privacy.yaml
- censorship.yaml

Mode:
vertical Run lessons strictly in order (default)
horizontal Run all lessons, order doesn't matter

State is tracked per-sequence so runs can be resumed.`,
	RunE: runSequence,
}
|
||
|
|
|
||
|
|
// Values bound to the sequence command's flags (see init).
var (
	// Paths: --file, --model-path and --output respectively.
	sequenceFile, sequenceModelPath, sequenceOutput string

	// Generation limit from --max-tokens, and the --memory-limit value in GB.
	sequenceMaxTokens, sequenceMemLimit int

	// Sampling temperature from --temperature.
	sequenceTemp float64
)
|
||
|
|
|
||
|
|
func init() {
|
||
|
|
sequenceCmd.Flags().StringVar(&sequenceFile, "file", "", "Sequence YAML file (required)")
|
||
|
|
sequenceCmd.Flags().StringVar(&sequenceModelPath, "model-path", "", "Path to model directory (required)")
|
||
|
|
sequenceCmd.Flags().StringVar(&sequenceOutput, "output", "", "Output JSONL file (default: <sequence-id>.jsonl)")
|
||
|
|
sequenceCmd.Flags().IntVar(&sequenceMaxTokens, "max-tokens", 1024, "Max tokens per response")
|
||
|
|
sequenceCmd.Flags().Float64Var(&sequenceTemp, "temperature", 0.4, "Sampling temperature")
|
||
|
|
sequenceCmd.Flags().IntVar(&sequenceMemLimit, "memory-limit", 24, "Metal memory limit in GB")
|
||
|
|
sequenceCmd.MarkFlagRequired("file")
|
||
|
|
sequenceCmd.MarkFlagRequired("model-path")
|
||
|
|
}
|
||
|
|
|
||
|
|
// sequenceDef is a YAML sequence definition: the document runSequence parses
// from the file named by the --file flag.
type sequenceDef struct {
	// ID names the sequence. runSequence derives the default output file and
	// the state file name from it, and substitutes the sequence file's base
	// name (without extension) when it is empty.
	ID string `yaml:"id"`
	// Title is a human-readable label; runSequence only logs it.
	Title string `yaml:"title"`
	// Mode controls what happens when a lesson file cannot be read or parsed:
	// "vertical" aborts the run, any other value logs and moves on.
	Mode string `yaml:"mode"` // "vertical" (default) or "horizontal"
	// ModelPath is used by runSequence only when the --model-path flag value
	// is empty.
	ModelPath string `yaml:"model-path"`
	// Lessons lists the lesson files in run order. Non-absolute entries are
	// resolved against the directory containing the sequence file.
	Lessons []string `yaml:"lessons"` // Relative paths to lesson YAML files
}
|
||
|
|
|
||
|
|
// sequenceState tracks progress through a sequence. It is stored as JSON in
// "<sequence-id>.sequence-state.json" so that a later run can skip lessons
// that are already recorded as completed.
type sequenceState struct {
	// SequenceID is the ID of the sequence this state belongs to. An empty
	// value is treated by runSequence as "no saved state".
	SequenceID string `json:"sequence_id"`
	// Completed is keyed by lesson ID; a true entry makes runSequence skip
	// that lesson.
	Completed map[string]bool `json:"completed"` // lesson ID → done
	// Current is the ID of the lesson being processed; runSequence clears it
	// once the whole sequence has been walked.
	Current string `json:"current"`
	// UpdatedAt is the time of the last state change, formatted as RFC 3339.
	UpdatedAt string `json:"updated_at"`
}
|
||
|
|
|
||
|
|
func runSequence(cmd *cli.Command, args []string) error {
|
||
|
|
start := time.Now()
|
||
|
|
|
||
|
|
// Load sequence YAML
|
||
|
|
data, err := os.ReadFile(sequenceFile)
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("read sequence: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
var seq sequenceDef
|
||
|
|
if err := yaml.Unmarshal(data, &seq); err != nil {
|
||
|
|
return fmt.Errorf("parse sequence: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
if seq.ID == "" {
|
||
|
|
seq.ID = strings.TrimSuffix(filepath.Base(sequenceFile), filepath.Ext(sequenceFile))
|
||
|
|
}
|
||
|
|
if seq.Mode == "" {
|
||
|
|
seq.Mode = "vertical"
|
||
|
|
}
|
||
|
|
|
||
|
|
// Model path from sequence or flag
|
||
|
|
modelPath := sequenceModelPath
|
||
|
|
if modelPath == "" && seq.ModelPath != "" {
|
||
|
|
modelPath = seq.ModelPath
|
||
|
|
}
|
||
|
|
if modelPath == "" {
|
||
|
|
return fmt.Errorf("model-path is required (flag or sequence YAML)")
|
||
|
|
}
|
||
|
|
|
||
|
|
// Resolve output
|
||
|
|
if sequenceOutput == "" {
|
||
|
|
sequenceOutput = seq.ID + ".jsonl"
|
||
|
|
}
|
||
|
|
|
||
|
|
slog.Info("sequence: loaded",
|
||
|
|
"id", seq.ID,
|
||
|
|
"title", seq.Title,
|
||
|
|
"mode", seq.Mode,
|
||
|
|
"lessons", len(seq.Lessons),
|
||
|
|
)
|
||
|
|
|
||
|
|
// Load state
|
||
|
|
stateFile := seq.ID + ".sequence-state.json"
|
||
|
|
state := loadSequenceState(stateFile)
|
||
|
|
if state.SequenceID == "" {
|
||
|
|
state.SequenceID = seq.ID
|
||
|
|
state.Completed = make(map[string]bool)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Load model once for all lessons
|
||
|
|
slog.Info("sequence: loading model", "path", modelPath)
|
||
|
|
backend, err := ml.NewMLXBackend(modelPath)
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("load model: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
opts := ml.GenOpts{
|
||
|
|
Temperature: sequenceTemp,
|
||
|
|
MaxTokens: sequenceMaxTokens,
|
||
|
|
}
|
||
|
|
|
||
|
|
// Open output file
|
||
|
|
outFile, err := os.OpenFile(sequenceOutput, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("create output: %w", err)
|
||
|
|
}
|
||
|
|
defer outFile.Close()
|
||
|
|
encoder := json.NewEncoder(outFile)
|
||
|
|
|
||
|
|
baseDir := filepath.Dir(sequenceFile)
|
||
|
|
totalGenerated := 0
|
||
|
|
|
||
|
|
for i, lessonPath := range seq.Lessons {
|
||
|
|
// Resolve lesson path
|
||
|
|
if !filepath.IsAbs(lessonPath) {
|
||
|
|
lessonPath = filepath.Join(baseDir, lessonPath)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Load lesson
|
||
|
|
lessonData, err := os.ReadFile(lessonPath)
|
||
|
|
if err != nil {
|
||
|
|
slog.Error("sequence: failed to read lesson",
|
||
|
|
"path", lessonPath,
|
||
|
|
"error", err,
|
||
|
|
)
|
||
|
|
if seq.Mode == "vertical" {
|
||
|
|
return fmt.Errorf("vertical sequence halted: %w", err)
|
||
|
|
}
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
|
||
|
|
var lesson lessonDef
|
||
|
|
if err := yaml.Unmarshal(lessonData, &lesson); err != nil {
|
||
|
|
slog.Error("sequence: failed to parse lesson",
|
||
|
|
"path", lessonPath,
|
||
|
|
"error", err,
|
||
|
|
)
|
||
|
|
if seq.Mode == "vertical" {
|
||
|
|
return fmt.Errorf("vertical sequence halted: %w", err)
|
||
|
|
}
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
|
||
|
|
if lesson.ID == "" {
|
||
|
|
lesson.ID = strings.TrimSuffix(filepath.Base(lessonPath), filepath.Ext(lessonPath))
|
||
|
|
}
|
||
|
|
|
||
|
|
// Skip completed lessons
|
||
|
|
if state.Completed[lesson.ID] {
|
||
|
|
slog.Info("sequence: skipping completed lesson",
|
||
|
|
"lesson", fmt.Sprintf("%d/%d", i+1, len(seq.Lessons)),
|
||
|
|
"id", lesson.ID,
|
||
|
|
)
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
|
||
|
|
state.Current = lesson.ID
|
||
|
|
|
||
|
|
slog.Info("sequence: starting lesson",
|
||
|
|
"lesson", fmt.Sprintf("%d/%d", i+1, len(seq.Lessons)),
|
||
|
|
"id", lesson.ID,
|
||
|
|
"title", lesson.Title,
|
||
|
|
"prompts", len(lesson.Prompts),
|
||
|
|
)
|
||
|
|
|
||
|
|
// Load sandwich files for this lesson
|
||
|
|
var kbText, kernelText string
|
||
|
|
hasSandwich := false
|
||
|
|
if lesson.Sandwich != nil {
|
||
|
|
lessonDir := filepath.Dir(lessonPath)
|
||
|
|
if lesson.Sandwich.KB != "" {
|
||
|
|
kbPath := lesson.Sandwich.KB
|
||
|
|
if !filepath.IsAbs(kbPath) {
|
||
|
|
kbPath = filepath.Join(lessonDir, kbPath)
|
||
|
|
}
|
||
|
|
d, err := os.ReadFile(kbPath)
|
||
|
|
if err != nil {
|
||
|
|
slog.Error("sequence: failed to read KB", "error", err)
|
||
|
|
} else {
|
||
|
|
kbText = string(d)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if lesson.Sandwich.Kernel != "" {
|
||
|
|
kernelPath := lesson.Sandwich.Kernel
|
||
|
|
if !filepath.IsAbs(kernelPath) {
|
||
|
|
kernelPath = filepath.Join(lessonDir, kernelPath)
|
||
|
|
}
|
||
|
|
d, err := os.ReadFile(kernelPath)
|
||
|
|
if err != nil {
|
||
|
|
slog.Error("sequence: failed to read kernel", "error", err)
|
||
|
|
} else {
|
||
|
|
kernelText = string(d)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
hasSandwich = kbText != "" && kernelText != ""
|
||
|
|
}
|
||
|
|
|
||
|
|
// Run each prompt in the lesson
|
||
|
|
generated := 0
|
||
|
|
for j, prompt := range lesson.Prompts {
|
||
|
|
var messages []ml.Message
|
||
|
|
if lesson.System != "" {
|
||
|
|
messages = append(messages, ml.Message{Role: "system", Content: lesson.System})
|
||
|
|
}
|
||
|
|
|
||
|
|
userContent := prompt.Prompt
|
||
|
|
if hasSandwich {
|
||
|
|
userContent = buildSandwich(kbText, prompt.Prompt, kernelText)
|
||
|
|
}
|
||
|
|
messages = append(messages, ml.Message{Role: "user", Content: userContent})
|
||
|
|
|
||
|
|
slog.Info("sequence: generating",
|
||
|
|
"lesson", lesson.ID,
|
||
|
|
"prompt", fmt.Sprintf("%d/%d", j+1, len(lesson.Prompts)),
|
||
|
|
"id", prompt.ID,
|
||
|
|
)
|
||
|
|
|
||
|
|
response, err := backend.Chat(cmd.Context(), messages, opts)
|
||
|
|
if err != nil {
|
||
|
|
slog.Error("sequence: generation failed",
|
||
|
|
"lesson", lesson.ID,
|
||
|
|
"prompt", prompt.ID,
|
||
|
|
"error", err,
|
||
|
|
)
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
|
||
|
|
record := struct {
|
||
|
|
Messages []ml.Message `json:"messages"`
|
||
|
|
}{
|
||
|
|
Messages: []ml.Message{
|
||
|
|
{Role: "user", Content: userContent},
|
||
|
|
{Role: "assistant", Content: response},
|
||
|
|
},
|
||
|
|
}
|
||
|
|
if err := encoder.Encode(record); err != nil {
|
||
|
|
return fmt.Errorf("write record: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
generated++
|
||
|
|
totalGenerated++
|
||
|
|
}
|
||
|
|
|
||
|
|
// Mark lesson complete
|
||
|
|
state.Completed[lesson.ID] = true
|
||
|
|
state.UpdatedAt = time.Now().Format(time.RFC3339)
|
||
|
|
saveSequenceState(stateFile, state)
|
||
|
|
|
||
|
|
slog.Info("sequence: lesson complete",
|
||
|
|
"id", lesson.ID,
|
||
|
|
"generated", generated,
|
||
|
|
"total", len(lesson.Prompts),
|
||
|
|
)
|
||
|
|
}
|
||
|
|
|
||
|
|
state.Current = ""
|
||
|
|
state.UpdatedAt = time.Now().Format(time.RFC3339)
|
||
|
|
saveSequenceState(stateFile, state)
|
||
|
|
|
||
|
|
slog.Info("sequence: complete",
|
||
|
|
"id", seq.ID,
|
||
|
|
"output", sequenceOutput,
|
||
|
|
"total_generated", totalGenerated,
|
||
|
|
"lessons_completed", len(state.Completed),
|
||
|
|
"duration", time.Since(start).Round(time.Second),
|
||
|
|
)
|
||
|
|
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func loadSequenceState(path string) sequenceState {
|
||
|
|
data, err := os.ReadFile(path)
|
||
|
|
if err != nil {
|
||
|
|
return sequenceState{}
|
||
|
|
}
|
||
|
|
var state sequenceState
|
||
|
|
json.Unmarshal(data, &state)
|
||
|
|
return state
|
||
|
|
}
|
||
|
|
|
||
|
|
func saveSequenceState(path string, state sequenceState) {
|
||
|
|
data, err := json.MarshalIndent(state, "", " ")
|
||
|
|
if err != nil {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
os.WriteFile(path, data, 0644)
|
||
|
|
}
|