cli/cmd/ml/cmd_lesson.go

341 lines
8.5 KiB
Go
Raw Permalink Normal View History

//go:build darwin && arm64
package ml
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"os"
"path/filepath"
"runtime"
"strings"
"time"
"forge.lthn.ai/core/go-ai/ml"
"forge.lthn.ai/core/go/pkg/cli"
"gopkg.in/yaml.v3"
)
var lessonCmd = &cli.Command{
Use: "lesson",
Short: "Run a structured training lesson from a YAML definition",
Long: `Runs a training lesson defined in a YAML file. Each lesson contains
prompts organised by category, optional system prompt, and sandwich
signing configuration.
Lesson YAML format:
id: lek-sovereignty
title: "Sovereignty Lessons"
system: "You are a helpful assistant."
sandwich:
kb: path/to/axioms.md
kernel: path/to/kernel.txt
prompts:
- id: P01
category: sovereignty
prompt: "A user wants to build an auth system."
signal: "Does the model prefer decentralised?"
The command generates responses for each prompt and writes them as
training JSONL. State is tracked so lessons can be resumed.`,
RunE: runLesson,
}
var (
lessonFile string
lessonModelPath string
lessonOutput string
lessonMaxTokens int
lessonTemp float64
lessonMemLimit int
lessonResume bool
lessonInteract bool
)
func init() {
lessonCmd.Flags().StringVar(&lessonFile, "file", "", "Lesson YAML file (required)")
lessonCmd.Flags().StringVar(&lessonModelPath, "model-path", "", "Path to model directory (required)")
lessonCmd.Flags().StringVar(&lessonOutput, "output", "", "Output JSONL file (default: <lesson-id>.jsonl)")
lessonCmd.Flags().IntVar(&lessonMaxTokens, "max-tokens", 1024, "Max tokens per response")
lessonCmd.Flags().Float64Var(&lessonTemp, "temperature", 0.4, "Sampling temperature")
lessonCmd.Flags().IntVar(&lessonMemLimit, "memory-limit", 24, "Metal memory limit in GB")
lessonCmd.Flags().BoolVar(&lessonResume, "resume", true, "Resume from last completed prompt")
lessonCmd.Flags().BoolVar(&lessonInteract, "interactive", false, "Interactive mode: review each response before continuing")
lessonCmd.MarkFlagRequired("file")
lessonCmd.MarkFlagRequired("model-path")
}
// lessonDef is a YAML lesson definition.
type lessonDef struct {
ID string `yaml:"id"`
Title string `yaml:"title"`
System string `yaml:"system"`
Sandwich *lessonSandwichCfg `yaml:"sandwich"`
Prompts []lessonPrompt `yaml:"prompts"`
}
type lessonSandwichCfg struct {
KB string `yaml:"kb"`
Kernel string `yaml:"kernel"`
}
type lessonPrompt struct {
ID string `yaml:"id"`
Category string `yaml:"category"`
Prompt string `yaml:"prompt"`
Signal string `yaml:"signal"`
}
// lessonState tracks progress through a lesson.
type lessonState struct {
LessonID string `json:"lesson_id"`
Completed map[string]lessonResult `json:"completed"`
UpdatedAt string `json:"updated_at"`
}
type lessonResult struct {
ResponseChars int `json:"response_chars"`
Duration string `json:"duration"`
CompletedAt string `json:"completed_at"`
}
func runLesson(cmd *cli.Command, args []string) error {
start := time.Now()
// Load lesson YAML
data, err := os.ReadFile(lessonFile)
if err != nil {
return fmt.Errorf("read lesson: %w", err)
}
var lesson lessonDef
if err := yaml.Unmarshal(data, &lesson); err != nil {
return fmt.Errorf("parse lesson: %w", err)
}
if lesson.ID == "" {
lesson.ID = strings.TrimSuffix(filepath.Base(lessonFile), filepath.Ext(lessonFile))
}
// Resolve output path
if lessonOutput == "" {
lessonOutput = lesson.ID + ".jsonl"
}
// Load sandwich files if configured
var kbText, kernelText string
sandwich := false
if lesson.Sandwich != nil {
baseDir := filepath.Dir(lessonFile)
if lesson.Sandwich.KB != "" {
kbPath := lesson.Sandwich.KB
if !filepath.IsAbs(kbPath) {
kbPath = filepath.Join(baseDir, kbPath)
}
d, err := os.ReadFile(kbPath)
if err != nil {
return fmt.Errorf("read KB: %w", err)
}
kbText = string(d)
}
if lesson.Sandwich.Kernel != "" {
kernelPath := lesson.Sandwich.Kernel
if !filepath.IsAbs(kernelPath) {
kernelPath = filepath.Join(baseDir, kernelPath)
}
d, err := os.ReadFile(kernelPath)
if err != nil {
return fmt.Errorf("read kernel: %w", err)
}
kernelText = string(d)
}
sandwich = kbText != "" && kernelText != ""
}
slog.Info("lesson: loaded",
"id", lesson.ID,
"title", lesson.Title,
"prompts", len(lesson.Prompts),
"sandwich", sandwich,
)
if len(lesson.Prompts) == 0 {
return fmt.Errorf("lesson has no prompts")
}
// Load state for resume
stateFile := lesson.ID + ".state.json"
state := loadLessonState(stateFile)
if state.LessonID == "" {
state.LessonID = lesson.ID
state.Completed = make(map[string]lessonResult)
}
// Count remaining
var remaining []lessonPrompt
for _, p := range lesson.Prompts {
if lessonResume {
if _, done := state.Completed[p.ID]; done {
continue
}
}
remaining = append(remaining, p)
}
if len(remaining) == 0 {
slog.Info("lesson: all prompts completed",
"id", lesson.ID,
"total", len(lesson.Prompts),
)
return nil
}
slog.Info("lesson: starting",
"remaining", len(remaining),
"completed", len(state.Completed),
"total", len(lesson.Prompts),
)
// Load model
slog.Info("lesson: loading model", "path", lessonModelPath)
backend, err := ml.NewMLXBackend(lessonModelPath)
if err != nil {
return fmt.Errorf("load model: %w", err)
}
opts := ml.GenOpts{
Temperature: lessonTemp,
MaxTokens: lessonMaxTokens,
}
// Open output file (append mode for resume)
outFile, err := os.OpenFile(lessonOutput, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
if err != nil {
return fmt.Errorf("create output: %w", err)
}
defer outFile.Close()
encoder := json.NewEncoder(outFile)
generated := 0
for i, prompt := range remaining {
promptStart := time.Now()
slog.Info("lesson: generating",
"prompt", fmt.Sprintf("%d/%d", i+1, len(remaining)),
"id", prompt.ID,
"category", prompt.Category,
)
// Build messages
var messages []ml.Message
if lesson.System != "" {
messages = append(messages, ml.Message{Role: "system", Content: lesson.System})
}
userContent := prompt.Prompt
if sandwich {
userContent = buildSandwich(kbText, prompt.Prompt, kernelText)
}
messages = append(messages, ml.Message{Role: "user", Content: userContent})
// Generate
response, err := backend.Chat(context.Background(), messages, opts)
if err != nil {
slog.Error("lesson: generation failed",
"id", prompt.ID,
"error", err,
)
continue
}
elapsed := time.Since(promptStart)
// Write training record
record := struct {
Messages []ml.Message `json:"messages"`
}{
Messages: []ml.Message{
{Role: "user", Content: userContent},
{Role: "assistant", Content: response},
},
}
if err := encoder.Encode(record); err != nil {
return fmt.Errorf("write record: %w", err)
}
// Update state
state.Completed[prompt.ID] = lessonResult{
ResponseChars: len(response),
Duration: elapsed.Round(time.Second).String(),
CompletedAt: time.Now().Format(time.RFC3339),
}
state.UpdatedAt = time.Now().Format(time.RFC3339)
if err := saveLessonState(stateFile, state); err != nil {
slog.Warn("lesson: failed to save state", "error", err)
}
generated++
slog.Info("lesson: generated",
"id", prompt.ID,
"category", prompt.Category,
"response_chars", len(response),
"duration", elapsed.Round(time.Second),
)
// Interactive mode: show response and wait for confirmation
if lessonInteract {
fmt.Printf("\n--- %s (%s) ---\n", prompt.ID, prompt.Category)
fmt.Printf("Prompt: %s\n\n", prompt.Prompt)
if prompt.Signal != "" {
fmt.Printf("Signal: %s\n\n", prompt.Signal)
}
fmt.Printf("Response:\n%s\n", response)
fmt.Printf("\nPress Enter to continue (or 'q' to stop)... ")
var input string
fmt.Scanln(&input)
if strings.TrimSpace(input) == "q" {
break
}
}
// Periodic cleanup
if (i+1)%4 == 0 {
runtime.GC()
}
}
slog.Info("lesson: complete",
"id", lesson.ID,
"output", lessonOutput,
"generated", generated,
"total_completed", len(state.Completed),
"total_prompts", len(lesson.Prompts),
"duration", time.Since(start).Round(time.Second),
)
return nil
}
func loadLessonState(path string) lessonState {
data, err := os.ReadFile(path)
if err != nil {
return lessonState{}
}
var state lessonState
json.Unmarshal(data, &state)
return state
}
func saveLessonState(path string, state lessonState) error {
data, err := json.MarshalIndent(state, "", " ")
if err != nil {
return err
}
return os.WriteFile(path, data, 0644)
}