328 lines
8.5 KiB
Go
328 lines
8.5 KiB
Go
|
|
//go:build darwin && arm64
|
||
|
|
|
||
|
|
package ml
|
||
|
|
|
||
|
|
import (
|
||
|
|
"bufio"
|
||
|
|
"encoding/json"
|
||
|
|
"fmt"
|
||
|
|
"log/slog"
|
||
|
|
"os"
|
||
|
|
"runtime"
|
||
|
|
"strings"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"forge.lthn.ai/core/go-ai/ml"
|
||
|
|
"forge.lthn.ai/core/go/pkg/cli"
|
||
|
|
)
|
||
|
|
|
||
|
|
var chatCmd = &cli.Command{
|
||
|
|
Use: "chat",
|
||
|
|
Short: "Interactive conversation with a local MLX model",
|
||
|
|
Long: `Start an interactive chat session with a local MLX model.
|
||
|
|
|
||
|
|
All exchanges are captured and can be written to training JSONL on exit
|
||
|
|
for use with 'core ml train'. Optionally apply axiom sandwich signing
|
||
|
|
to wrap the conversation for LEK training.
|
||
|
|
|
||
|
|
Commands during chat:
|
||
|
|
/quit, /exit End session and save
|
||
|
|
/save Save conversation so far (appends to output)
|
||
|
|
/clear Clear conversation history
|
||
|
|
/system <text> Set system prompt
|
||
|
|
/undo Remove last exchange`,
|
||
|
|
RunE: runChat,
|
||
|
|
}
|
||
|
|
|
||
|
|
var (
|
||
|
|
chatModelPath string
|
||
|
|
chatOutput string
|
||
|
|
chatKB string
|
||
|
|
chatKernel string
|
||
|
|
chatSystem string
|
||
|
|
chatMaxTokens int
|
||
|
|
chatTemp float64
|
||
|
|
chatMemLimit int
|
||
|
|
)
|
||
|
|
|
||
|
|
func init() {
|
||
|
|
chatCmd.Flags().StringVar(&chatModelPath, "model-path", "", "Path to model directory (required)")
|
||
|
|
chatCmd.Flags().StringVar(&chatOutput, "output", "", "Output JSONL file for captured conversation")
|
||
|
|
chatCmd.Flags().StringVar(&chatKB, "kb", "", "Knowledge base document for sandwich signing")
|
||
|
|
chatCmd.Flags().StringVar(&chatKernel, "kernel", "", "LEK-1 kernel file for sandwich signing")
|
||
|
|
chatCmd.Flags().StringVar(&chatSystem, "system", "", "Initial system prompt")
|
||
|
|
chatCmd.Flags().IntVar(&chatMaxTokens, "max-tokens", 2048, "Max tokens per response")
|
||
|
|
chatCmd.Flags().Float64Var(&chatTemp, "temperature", 0.4, "Sampling temperature")
|
||
|
|
chatCmd.Flags().IntVar(&chatMemLimit, "memory-limit", 24, "Metal memory limit in GB")
|
||
|
|
chatCmd.MarkFlagRequired("model-path")
|
||
|
|
}
|
||
|
|
|
||
|
|
func runChat(cmd *cli.Command, args []string) error {
|
||
|
|
// Load optional KB and kernel for sandwich signing
|
||
|
|
var kbText, kernelText string
|
||
|
|
if chatKB != "" {
|
||
|
|
data, err := os.ReadFile(chatKB)
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("read KB: %w", err)
|
||
|
|
}
|
||
|
|
kbText = string(data)
|
||
|
|
}
|
||
|
|
if chatKernel != "" {
|
||
|
|
data, err := os.ReadFile(chatKernel)
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("read kernel: %w", err)
|
||
|
|
}
|
||
|
|
kernelText = string(data)
|
||
|
|
}
|
||
|
|
sandwich := kbText != "" && kernelText != ""
|
||
|
|
|
||
|
|
// Load model
|
||
|
|
slog.Info("chat: loading model", "path", chatModelPath)
|
||
|
|
backend, err := ml.NewMLXBackend(chatModelPath)
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("load model: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
opts := ml.GenOpts{
|
||
|
|
Temperature: chatTemp,
|
||
|
|
MaxTokens: chatMaxTokens,
|
||
|
|
}
|
||
|
|
|
||
|
|
// Conversation state
|
||
|
|
var history []ml.Message
|
||
|
|
if chatSystem != "" {
|
||
|
|
history = append(history, ml.Message{Role: "system", Content: chatSystem})
|
||
|
|
}
|
||
|
|
|
||
|
|
// Track saved conversations for JSONL output
|
||
|
|
var savedConversations [][]ml.Message
|
||
|
|
|
||
|
|
fmt.Println("Chat started. Type /quit to exit, /help for commands.")
|
||
|
|
if sandwich {
|
||
|
|
fmt.Println("Sandwich signing enabled (KB + kernel)")
|
||
|
|
}
|
||
|
|
if chatOutput != "" {
|
||
|
|
fmt.Printf("Capturing to: %s\n", chatOutput)
|
||
|
|
}
|
||
|
|
fmt.Println()
|
||
|
|
|
||
|
|
scanner := bufio.NewScanner(os.Stdin)
|
||
|
|
scanner.Buffer(make([]byte, 1<<20), 1<<20) // 1MB input buffer
|
||
|
|
|
||
|
|
for {
|
||
|
|
fmt.Print("you> ")
|
||
|
|
if !scanner.Scan() {
|
||
|
|
// EOF (Ctrl+D)
|
||
|
|
break
|
||
|
|
}
|
||
|
|
|
||
|
|
input := strings.TrimSpace(scanner.Text())
|
||
|
|
if input == "" {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
|
||
|
|
// Handle commands
|
||
|
|
if strings.HasPrefix(input, "/") {
|
||
|
|
cmd := strings.Fields(input)
|
||
|
|
switch cmd[0] {
|
||
|
|
case "/quit", "/exit":
|
||
|
|
goto done
|
||
|
|
case "/save":
|
||
|
|
if chatOutput == "" {
|
||
|
|
fmt.Println("No --output file specified. Use --output to enable saving.")
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
if len(history) > 0 {
|
||
|
|
savedConversations = append(savedConversations, cloneMessages(history))
|
||
|
|
fmt.Printf("Saved conversation (%d messages)\n", len(history))
|
||
|
|
}
|
||
|
|
continue
|
||
|
|
case "/clear":
|
||
|
|
sysPrompt := ""
|
||
|
|
for _, m := range history {
|
||
|
|
if m.Role == "system" {
|
||
|
|
sysPrompt = m.Content
|
||
|
|
break
|
||
|
|
}
|
||
|
|
}
|
||
|
|
history = nil
|
||
|
|
if sysPrompt != "" {
|
||
|
|
history = append(history, ml.Message{Role: "system", Content: sysPrompt})
|
||
|
|
}
|
||
|
|
fmt.Println("Conversation cleared.")
|
||
|
|
continue
|
||
|
|
case "/system":
|
||
|
|
if len(cmd) < 2 {
|
||
|
|
fmt.Println("Usage: /system <prompt text>")
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
sysText := strings.TrimPrefix(input, "/system ")
|
||
|
|
// Replace existing system prompt or add new one
|
||
|
|
found := false
|
||
|
|
for i, m := range history {
|
||
|
|
if m.Role == "system" {
|
||
|
|
history[i].Content = sysText
|
||
|
|
found = true
|
||
|
|
break
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if !found {
|
||
|
|
// Prepend system message
|
||
|
|
history = append([]ml.Message{{Role: "system", Content: sysText}}, history...)
|
||
|
|
}
|
||
|
|
fmt.Printf("System prompt set (%d chars)\n", len(sysText))
|
||
|
|
continue
|
||
|
|
case "/undo":
|
||
|
|
// Remove last user+assistant pair
|
||
|
|
if len(history) >= 2 {
|
||
|
|
last := history[len(history)-1]
|
||
|
|
secondLast := history[len(history)-2]
|
||
|
|
if secondLast.Role == "user" && last.Role == "assistant" {
|
||
|
|
history = history[:len(history)-2]
|
||
|
|
fmt.Println("Last exchange removed.")
|
||
|
|
} else {
|
||
|
|
fmt.Println("Cannot undo: last messages are not a user/assistant pair.")
|
||
|
|
}
|
||
|
|
} else {
|
||
|
|
fmt.Println("Nothing to undo.")
|
||
|
|
}
|
||
|
|
continue
|
||
|
|
case "/help":
|
||
|
|
fmt.Println("Commands:")
|
||
|
|
fmt.Println(" /quit, /exit End session and save")
|
||
|
|
fmt.Println(" /save Save conversation so far")
|
||
|
|
fmt.Println(" /clear Clear conversation history")
|
||
|
|
fmt.Println(" /system <text> Set system prompt")
|
||
|
|
fmt.Println(" /undo Remove last exchange")
|
||
|
|
fmt.Println(" /help Show this help")
|
||
|
|
continue
|
||
|
|
default:
|
||
|
|
fmt.Printf("Unknown command: %s (try /help)\n", cmd[0])
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Add user message
|
||
|
|
history = append(history, ml.Message{Role: "user", Content: input})
|
||
|
|
|
||
|
|
// Generate response
|
||
|
|
genStart := time.Now()
|
||
|
|
fmt.Print("\nassistant> ")
|
||
|
|
|
||
|
|
var response strings.Builder
|
||
|
|
err := backend.ChatStream(cmd.Context(), history, opts, func(token string) error {
|
||
|
|
fmt.Print(token)
|
||
|
|
response.WriteString(token)
|
||
|
|
return nil
|
||
|
|
})
|
||
|
|
fmt.Println()
|
||
|
|
|
||
|
|
if err != nil {
|
||
|
|
slog.Error("chat: generation failed", "error", err)
|
||
|
|
// Remove the failed user message
|
||
|
|
history = history[:len(history)-1]
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
|
||
|
|
elapsed := time.Since(genStart)
|
||
|
|
responseText := response.String()
|
||
|
|
history = append(history, ml.Message{Role: "assistant", Content: responseText})
|
||
|
|
|
||
|
|
slog.Debug("chat: response generated",
|
||
|
|
"chars", len(responseText),
|
||
|
|
"duration", elapsed.Round(time.Millisecond),
|
||
|
|
)
|
||
|
|
|
||
|
|
// Periodic cleanup
|
||
|
|
if len(history)%8 == 0 {
|
||
|
|
runtime.GC()
|
||
|
|
}
|
||
|
|
|
||
|
|
fmt.Println()
|
||
|
|
}
|
||
|
|
|
||
|
|
done:
|
||
|
|
fmt.Println()
|
||
|
|
|
||
|
|
// Save final conversation if output is specified
|
||
|
|
if chatOutput != "" && len(history) > 0 {
|
||
|
|
// Include current conversation if not already saved
|
||
|
|
savedConversations = append(savedConversations, history)
|
||
|
|
|
||
|
|
if err := writeChatJSONL(chatOutput, savedConversations, sandwich, kbText, kernelText); err != nil {
|
||
|
|
return fmt.Errorf("save conversation: %w", err)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// writeChatJSONL writes conversations to JSONL file.
|
||
|
|
// If sandwich is true, wraps user messages with KB + kernel signing.
|
||
|
|
func writeChatJSONL(path string, conversations [][]ml.Message, sandwich bool, kb, kernel string) error {
|
||
|
|
f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
|
||
|
|
if err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
defer f.Close()
|
||
|
|
|
||
|
|
encoder := json.NewEncoder(f)
|
||
|
|
written := 0
|
||
|
|
|
||
|
|
for _, conv := range conversations {
|
||
|
|
// Extract user/assistant pairs (skip system messages for training output)
|
||
|
|
var messages []ml.Message
|
||
|
|
for _, m := range conv {
|
||
|
|
if m.Role == "system" {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
messages = append(messages, m)
|
||
|
|
}
|
||
|
|
|
||
|
|
if len(messages) < 2 {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
|
||
|
|
if sandwich {
|
||
|
|
// Apply sandwich signing to user messages
|
||
|
|
messages = applySandwichSigning(messages, kb, kernel)
|
||
|
|
}
|
||
|
|
|
||
|
|
record := struct {
|
||
|
|
Messages []ml.Message `json:"messages"`
|
||
|
|
}{Messages: messages}
|
||
|
|
|
||
|
|
if err := encoder.Encode(record); err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
written++
|
||
|
|
}
|
||
|
|
|
||
|
|
slog.Info("chat: saved conversations",
|
||
|
|
"file", path,
|
||
|
|
"conversations", written,
|
||
|
|
"sandwich", sandwich,
|
||
|
|
)
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// applySandwichSigning wraps user messages with KB preamble and kernel postfix.
|
||
|
|
func applySandwichSigning(messages []ml.Message, kb, kernel string) []ml.Message {
|
||
|
|
signed := make([]ml.Message, len(messages))
|
||
|
|
copy(signed, messages)
|
||
|
|
|
||
|
|
for i := range signed {
|
||
|
|
if signed[i].Role == "user" {
|
||
|
|
signed[i].Content = buildSandwich(kb, signed[i].Content, kernel)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return signed
|
||
|
|
}
|
||
|
|
|
||
|
|
// cloneMessages creates a deep copy of a message slice.
|
||
|
|
func cloneMessages(msgs []ml.Message) []ml.Message {
|
||
|
|
clone := make([]ml.Message, len(msgs))
|
||
|
|
copy(clone, msgs)
|
||
|
|
return clone
|
||
|
|
}
|