//go:build darwin && arm64 package ml import ( "encoding/json" "fmt" "log/slog" "os" "path/filepath" "strings" "time" "forge.lthn.ai/core/go-ai/ml" "forge.lthn.ai/core/go/pkg/cli" "gopkg.in/yaml.v3" ) var sequenceCmd = &cli.Command{ Use: "sequence", Short: "Run a training sequence of multiple lessons", Long: `Runs an ordered sequence of lessons defined in a YAML file. Sequence YAML format: id: lek-full title: "LEK Full Training Sequence" mode: vertical model-path: /path/to/model lessons: - sovereignty.yaml - privacy.yaml - censorship.yaml Mode: vertical Run lessons strictly in order (default) horizontal Run all lessons, order doesn't matter State is tracked per-sequence so runs can be resumed.`, RunE: runSequence, } var ( sequenceFile string sequenceModelPath string sequenceOutput string sequenceMaxTokens int sequenceTemp float64 sequenceMemLimit int ) func init() { sequenceCmd.Flags().StringVar(&sequenceFile, "file", "", "Sequence YAML file (required)") sequenceCmd.Flags().StringVar(&sequenceModelPath, "model-path", "", "Path to model directory (required)") sequenceCmd.Flags().StringVar(&sequenceOutput, "output", "", "Output JSONL file (default: .jsonl)") sequenceCmd.Flags().IntVar(&sequenceMaxTokens, "max-tokens", 1024, "Max tokens per response") sequenceCmd.Flags().Float64Var(&sequenceTemp, "temperature", 0.4, "Sampling temperature") sequenceCmd.Flags().IntVar(&sequenceMemLimit, "memory-limit", 24, "Metal memory limit in GB") sequenceCmd.MarkFlagRequired("file") sequenceCmd.MarkFlagRequired("model-path") } // sequenceDef is a YAML sequence definition. type sequenceDef struct { ID string `yaml:"id"` Title string `yaml:"title"` Mode string `yaml:"mode"` // "vertical" (default) or "horizontal" ModelPath string `yaml:"model-path"` Lessons []string `yaml:"lessons"` // Relative paths to lesson YAML files } // sequenceState tracks progress through a sequence. type sequenceState struct { SequenceID string `json:"sequence_id"` Completed map[string]bool `json:"completed"` // lesson ID → done Current string `json:"current"` UpdatedAt string `json:"updated_at"` } func runSequence(cmd *cli.Command, args []string) error { start := time.Now() // Load sequence YAML data, err := os.ReadFile(sequenceFile) if err != nil { return fmt.Errorf("read sequence: %w", err) } var seq sequenceDef if err := yaml.Unmarshal(data, &seq); err != nil { return fmt.Errorf("parse sequence: %w", err) } if seq.ID == "" { seq.ID = strings.TrimSuffix(filepath.Base(sequenceFile), filepath.Ext(sequenceFile)) } if seq.Mode == "" { seq.Mode = "vertical" } // Model path from sequence or flag modelPath := sequenceModelPath if modelPath == "" && seq.ModelPath != "" { modelPath = seq.ModelPath } if modelPath == "" { return fmt.Errorf("model-path is required (flag or sequence YAML)") } // Resolve output if sequenceOutput == "" { sequenceOutput = seq.ID + ".jsonl" } slog.Info("sequence: loaded", "id", seq.ID, "title", seq.Title, "mode", seq.Mode, "lessons", len(seq.Lessons), ) // Load state stateFile := seq.ID + ".sequence-state.json" state := loadSequenceState(stateFile) if state.SequenceID == "" { state.SequenceID = seq.ID state.Completed = make(map[string]bool) } // Load model once for all lessons slog.Info("sequence: loading model", "path", modelPath) backend, err := ml.NewMLXBackend(modelPath) if err != nil { return fmt.Errorf("load model: %w", err) } opts := ml.GenOpts{ Temperature: sequenceTemp, MaxTokens: sequenceMaxTokens, } // Open output file outFile, err := os.OpenFile(sequenceOutput, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) if err != nil { return fmt.Errorf("create output: %w", err) } defer outFile.Close() encoder := json.NewEncoder(outFile) baseDir := filepath.Dir(sequenceFile) totalGenerated := 0 for i, lessonPath := range seq.Lessons { // Resolve lesson path if !filepath.IsAbs(lessonPath) { lessonPath = filepath.Join(baseDir, lessonPath) } // Load lesson lessonData, err := os.ReadFile(lessonPath) if err != nil { slog.Error("sequence: failed to read lesson", "path", lessonPath, "error", err, ) if seq.Mode == "vertical" { return fmt.Errorf("vertical sequence halted: %w", err) } continue } var lesson lessonDef if err := yaml.Unmarshal(lessonData, &lesson); err != nil { slog.Error("sequence: failed to parse lesson", "path", lessonPath, "error", err, ) if seq.Mode == "vertical" { return fmt.Errorf("vertical sequence halted: %w", err) } continue } if lesson.ID == "" { lesson.ID = strings.TrimSuffix(filepath.Base(lessonPath), filepath.Ext(lessonPath)) } // Skip completed lessons if state.Completed[lesson.ID] { slog.Info("sequence: skipping completed lesson", "lesson", fmt.Sprintf("%d/%d", i+1, len(seq.Lessons)), "id", lesson.ID, ) continue } state.Current = lesson.ID slog.Info("sequence: starting lesson", "lesson", fmt.Sprintf("%d/%d", i+1, len(seq.Lessons)), "id", lesson.ID, "title", lesson.Title, "prompts", len(lesson.Prompts), ) // Load sandwich files for this lesson var kbText, kernelText string hasSandwich := false if lesson.Sandwich != nil { lessonDir := filepath.Dir(lessonPath) if lesson.Sandwich.KB != "" { kbPath := lesson.Sandwich.KB if !filepath.IsAbs(kbPath) { kbPath = filepath.Join(lessonDir, kbPath) } d, err := os.ReadFile(kbPath) if err != nil { slog.Error("sequence: failed to read KB", "error", err) } else { kbText = string(d) } } if lesson.Sandwich.Kernel != "" { kernelPath := lesson.Sandwich.Kernel if !filepath.IsAbs(kernelPath) { kernelPath = filepath.Join(lessonDir, kernelPath) } d, err := os.ReadFile(kernelPath) if err != nil { slog.Error("sequence: failed to read kernel", "error", err) } else { kernelText = string(d) } } hasSandwich = kbText != "" && kernelText != "" } // Run each prompt in the lesson generated := 0 for j, prompt := range lesson.Prompts { var messages []ml.Message if lesson.System != "" { messages = append(messages, ml.Message{Role: "system", Content: lesson.System}) } userContent := prompt.Prompt if hasSandwich { userContent = buildSandwich(kbText, prompt.Prompt, kernelText) } messages = append(messages, ml.Message{Role: "user", Content: userContent}) slog.Info("sequence: generating", "lesson", lesson.ID, "prompt", fmt.Sprintf("%d/%d", j+1, len(lesson.Prompts)), "id", prompt.ID, ) response, err := backend.Chat(cmd.Context(), messages, opts) if err != nil { slog.Error("sequence: generation failed", "lesson", lesson.ID, "prompt", prompt.ID, "error", err, ) continue } record := struct { Messages []ml.Message `json:"messages"` }{ Messages: []ml.Message{ {Role: "user", Content: userContent}, {Role: "assistant", Content: response}, }, } if err := encoder.Encode(record); err != nil { return fmt.Errorf("write record: %w", err) } generated++ totalGenerated++ } // Mark lesson complete state.Completed[lesson.ID] = true state.UpdatedAt = time.Now().Format(time.RFC3339) saveSequenceState(stateFile, state) slog.Info("sequence: lesson complete", "id", lesson.ID, "generated", generated, "total", len(lesson.Prompts), ) } state.Current = "" state.UpdatedAt = time.Now().Format(time.RFC3339) saveSequenceState(stateFile, state) slog.Info("sequence: complete", "id", seq.ID, "output", sequenceOutput, "total_generated", totalGenerated, "lessons_completed", len(state.Completed), "duration", time.Since(start).Round(time.Second), ) return nil } func loadSequenceState(path string) sequenceState { data, err := os.ReadFile(path) if err != nil { return sequenceState{} } var state sequenceState json.Unmarshal(data, &state) return state } func saveSequenceState(path string, state sequenceState) { data, err := json.MarshalIndent(state, "", " ") if err != nil { return } os.WriteFile(path, data, 0644) }