Compare commits
1 commit
dev
...
docs/cli-c
| Author | SHA1 | Date | |
|---|---|---|---|
| 36f790ed22 |
118 changed files with 3096 additions and 3995 deletions
|
|
@ -3,4 +3,4 @@ version: '3'
|
||||||
tasks:
|
tasks:
|
||||||
build:
|
build:
|
||||||
cmds:
|
cmds:
|
||||||
- go build -o build/bin/core .
|
- go build -o build/bin/core cmd/app/main.go
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import (
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go-scm/agentci"
|
"forge.lthn.ai/core/go/pkg/agentci"
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go/pkg/config"
|
"forge.lthn.ai/core/go/pkg/config"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go-ai/agentic"
|
"forge.lthn.ai/core/go/pkg/agentic"
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go/pkg/i18n"
|
"forge.lthn.ai/core/go/pkg/i18n"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go-ai/ai"
|
"forge.lthn.ai/core/go/pkg/ai"
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go/pkg/i18n"
|
"forge.lthn.ai/core/go/pkg/i18n"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -9,8 +9,8 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go-ai/agentic"
|
"forge.lthn.ai/core/go/pkg/agentic"
|
||||||
"forge.lthn.ai/core/go-ai/ai"
|
"forge.lthn.ai/core/go/pkg/ai"
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go/pkg/i18n"
|
"forge.lthn.ai/core/go/pkg/i18n"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,8 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go-ai/agentic"
|
"forge.lthn.ai/core/go/pkg/agentic"
|
||||||
"forge.lthn.ai/core/go-ai/ai"
|
"forge.lthn.ai/core/go/pkg/ai"
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go/pkg/i18n"
|
"forge.lthn.ai/core/go/pkg/i18n"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-scm/collect"
|
"forge.lthn.ai/core/go/pkg/collect"
|
||||||
"forge.lthn.ai/core/go/pkg/i18n"
|
"forge.lthn.ai/core/go/pkg/i18n"
|
||||||
"forge.lthn.ai/core/go/pkg/io"
|
"forge.lthn.ai/core/go/pkg/io"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-scm/collect"
|
"forge.lthn.ai/core/go/pkg/collect"
|
||||||
"forge.lthn.ai/core/go/pkg/i18n"
|
"forge.lthn.ai/core/go/pkg/i18n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
collectpkg "forge.lthn.ai/core/go-scm/collect"
|
collectpkg "forge.lthn.ai/core/go/pkg/collect"
|
||||||
"forge.lthn.ai/core/go/pkg/i18n"
|
"forge.lthn.ai/core/go/pkg/i18n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-scm/collect"
|
"forge.lthn.ai/core/go/pkg/collect"
|
||||||
"forge.lthn.ai/core/go/pkg/i18n"
|
"forge.lthn.ai/core/go/pkg/i18n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-scm/collect"
|
"forge.lthn.ai/core/go/pkg/collect"
|
||||||
"forge.lthn.ai/core/go/pkg/i18n"
|
"forge.lthn.ai/core/go/pkg/i18n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-scm/collect"
|
"forge.lthn.ai/core/go/pkg/collect"
|
||||||
"forge.lthn.ai/core/go/pkg/i18n"
|
"forge.lthn.ai/core/go/pkg/i18n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-scm/collect"
|
"forge.lthn.ai/core/go/pkg/collect"
|
||||||
"forge.lthn.ai/core/go/pkg/i18n"
|
"forge.lthn.ai/core/go/pkg/i18n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-scm/collect"
|
"forge.lthn.ai/core/go/pkg/collect"
|
||||||
"forge.lthn.ai/core/go/pkg/i18n"
|
"forge.lthn.ai/core/go/pkg/i18n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import (
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-crypt/crypt"
|
"forge.lthn.ai/core/go/pkg/crypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Checksum command flags
|
// Checksum command flags
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-crypt/crypt"
|
"forge.lthn.ai/core/go/pkg/crypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Encrypt command flags
|
// Encrypt command flags
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-crypt/crypt"
|
"forge.lthn.ai/core/go/pkg/crypt"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,21 +3,13 @@ package daemon
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"syscall"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go/pkg/log"
|
"forge.lthn.ai/core/go/pkg/log"
|
||||||
"forge.lthn.ai/core/go/pkg/process"
|
"forge.lthn.ai/core/go/pkg/mcp"
|
||||||
"forge.lthn.ai/core/go-ai/mcp"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|
@ -55,6 +47,7 @@ func DefaultConfig() Config {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConfigFromEnv loads configuration from environment variables.
|
// ConfigFromEnv loads configuration from environment variables.
|
||||||
|
// Environment variables override default values.
|
||||||
func ConfigFromEnv() Config {
|
func ConfigFromEnv() Config {
|
||||||
cfg := DefaultConfig()
|
cfg := DefaultConfig()
|
||||||
|
|
||||||
|
|
@ -74,207 +67,40 @@ func ConfigFromEnv() Config {
|
||||||
return cfg
|
return cfg
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddDaemonCommand adds the 'daemon' command group to the root.
|
// AddDaemonCommand adds the 'daemon' command to the root.
|
||||||
func AddDaemonCommand(root *cli.Command) {
|
func AddDaemonCommand(root *cli.Command) {
|
||||||
cfg := ConfigFromEnv()
|
cfg := ConfigFromEnv()
|
||||||
|
|
||||||
daemonCmd := cli.NewGroup(
|
daemonCmd := cli.NewCommand(
|
||||||
"daemon",
|
"daemon",
|
||||||
"Manage the core daemon",
|
"Start the core daemon",
|
||||||
"Manage the core background daemon which provides long-running services.\n\n"+
|
"Starts the core daemon which provides long-running services like MCP.\n\n"+
|
||||||
"Subcommands:\n"+
|
"The daemon can be configured via environment variables or flags:\n"+
|
||||||
" start - Start the daemon in the background\n"+
|
" CORE_MCP_TRANSPORT - MCP transport type (stdio, tcp, socket)\n"+
|
||||||
" stop - Stop the running daemon\n"+
|
" CORE_MCP_ADDR - MCP address/path (e.g., :9100, /tmp/mcp.sock)\n"+
|
||||||
" status - Show daemon status\n"+
|
" CORE_HEALTH_ADDR - Health check endpoint address\n"+
|
||||||
" run - Run in foreground (for development/debugging)",
|
" CORE_PID_FILE - PID file path for single-instance enforcement",
|
||||||
|
func(cmd *cli.Command, args []string) error {
|
||||||
|
return runDaemon(cfg)
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
// Persistent flags inherited by all subcommands
|
// Flags override environment variables
|
||||||
cli.PersistentStringFlag(daemonCmd, &cfg.MCPTransport, "mcp-transport", "t", cfg.MCPTransport,
|
cli.StringFlag(daemonCmd, &cfg.MCPTransport, "mcp-transport", "t", cfg.MCPTransport,
|
||||||
"MCP transport type (stdio, tcp, socket)")
|
"MCP transport type (stdio, tcp, socket)")
|
||||||
cli.PersistentStringFlag(daemonCmd, &cfg.MCPAddr, "mcp-addr", "a", cfg.MCPAddr,
|
cli.StringFlag(daemonCmd, &cfg.MCPAddr, "mcp-addr", "a", cfg.MCPAddr,
|
||||||
"MCP listen address (e.g., :9100 or /tmp/mcp.sock)")
|
"MCP listen address (e.g., :9100 or /tmp/mcp.sock)")
|
||||||
cli.PersistentStringFlag(daemonCmd, &cfg.HealthAddr, "health-addr", "", cfg.HealthAddr,
|
cli.StringFlag(daemonCmd, &cfg.HealthAddr, "health-addr", "", cfg.HealthAddr,
|
||||||
"Health check endpoint address (empty to disable)")
|
"Health check endpoint address (empty to disable)")
|
||||||
cli.PersistentStringFlag(daemonCmd, &cfg.PIDFile, "pid-file", "", cfg.PIDFile,
|
cli.StringFlag(daemonCmd, &cfg.PIDFile, "pid-file", "", cfg.PIDFile,
|
||||||
"PID file path (empty to disable)")
|
"PID file path (empty to disable)")
|
||||||
|
|
||||||
// --- Subcommands ---
|
|
||||||
|
|
||||||
startCmd := cli.NewCommand("start", "Start the daemon in the background",
|
|
||||||
"Re-executes the core binary as a background daemon process.\n"+
|
|
||||||
"The daemon PID is written to the PID file for later management.",
|
|
||||||
func(cmd *cli.Command, args []string) error {
|
|
||||||
return runStart(cfg)
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
stopCmd := cli.NewCommand("stop", "Stop the running daemon",
|
|
||||||
"Sends SIGTERM to the daemon process identified by the PID file.\n"+
|
|
||||||
"Waits for graceful shutdown before returning.",
|
|
||||||
func(cmd *cli.Command, args []string) error {
|
|
||||||
return runStop(cfg)
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
statusCmd := cli.NewCommand("status", "Show daemon status",
|
|
||||||
"Checks if the daemon is running and queries its health endpoint.",
|
|
||||||
func(cmd *cli.Command, args []string) error {
|
|
||||||
return runStatus(cfg)
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
runCmd := cli.NewCommand("run", "Run the daemon in the foreground",
|
|
||||||
"Runs the daemon in the current terminal (blocks until SIGINT/SIGTERM).\n"+
|
|
||||||
"Useful for development, debugging, or running under a process manager.",
|
|
||||||
func(cmd *cli.Command, args []string) error {
|
|
||||||
return runForeground(cfg)
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
daemonCmd.AddCommand(startCmd, stopCmd, statusCmd, runCmd)
|
|
||||||
root.AddCommand(daemonCmd)
|
root.AddCommand(daemonCmd)
|
||||||
}
|
}
|
||||||
|
|
||||||
// runStart re-execs the current binary as a detached daemon process.
|
// runDaemon starts the daemon with the given configuration.
|
||||||
func runStart(cfg Config) error {
|
func runDaemon(cfg Config) error {
|
||||||
// Check if already running
|
// Set daemon mode environment for child processes
|
||||||
if pid, running := readPID(cfg.PIDFile); running {
|
|
||||||
return fmt.Errorf("daemon already running (PID %d)", pid)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find the current binary
|
|
||||||
exe, err := os.Executable()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to find executable: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build args for the foreground run command
|
|
||||||
args := []string{"daemon", "run",
|
|
||||||
"--mcp-transport", cfg.MCPTransport,
|
|
||||||
"--mcp-addr", cfg.MCPAddr,
|
|
||||||
"--health-addr", cfg.HealthAddr,
|
|
||||||
"--pid-file", cfg.PIDFile,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Launch detached child with CORE_DAEMON=1
|
|
||||||
cmd := exec.Command(exe, args...)
|
|
||||||
cmd.Env = append(os.Environ(), "CORE_DAEMON=1")
|
|
||||||
cmd.Stdout = nil
|
|
||||||
cmd.Stderr = nil
|
|
||||||
cmd.Stdin = nil
|
|
||||||
|
|
||||||
// Detach from parent process group
|
|
||||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
|
||||||
Setsid: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := cmd.Start(); err != nil {
|
|
||||||
return fmt.Errorf("failed to start daemon: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
pid := cmd.Process.Pid
|
|
||||||
|
|
||||||
// Release the child process so it runs independently
|
|
||||||
_ = cmd.Process.Release()
|
|
||||||
|
|
||||||
// Wait briefly for the health endpoint to come up
|
|
||||||
if cfg.HealthAddr != "" {
|
|
||||||
ready := waitForHealth(cfg.HealthAddr, 5*time.Second)
|
|
||||||
if ready {
|
|
||||||
log.Info("Daemon started", "pid", pid, "health", cfg.HealthAddr)
|
|
||||||
} else {
|
|
||||||
log.Info("Daemon started (health check not yet ready)", "pid", pid)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log.Info("Daemon started", "pid", pid)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// runStop sends SIGTERM to the daemon process.
|
|
||||||
func runStop(cfg Config) error {
|
|
||||||
pid, running := readPID(cfg.PIDFile)
|
|
||||||
if !running {
|
|
||||||
log.Info("Daemon is not running")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
proc, err := os.FindProcess(pid)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to find process %d: %w", pid, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Info("Stopping daemon", "pid", pid)
|
|
||||||
if err := proc.Signal(syscall.SIGTERM); err != nil {
|
|
||||||
return fmt.Errorf("failed to send SIGTERM to PID %d: %w", pid, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for the process to exit (poll PID file removal)
|
|
||||||
deadline := time.Now().Add(30 * time.Second)
|
|
||||||
for time.Now().Before(deadline) {
|
|
||||||
if _, still := readPID(cfg.PIDFile); !still {
|
|
||||||
log.Info("Daemon stopped")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
time.Sleep(250 * time.Millisecond)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Warn("Daemon did not stop within 30s, sending SIGKILL")
|
|
||||||
_ = proc.Signal(syscall.SIGKILL)
|
|
||||||
|
|
||||||
// Clean up stale PID file
|
|
||||||
_ = os.Remove(cfg.PIDFile)
|
|
||||||
log.Info("Daemon killed")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// runStatus checks daemon status via PID and health endpoint.
|
|
||||||
func runStatus(cfg Config) error {
|
|
||||||
pid, running := readPID(cfg.PIDFile)
|
|
||||||
if !running {
|
|
||||||
fmt.Println("Daemon is not running")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("Daemon is running (PID %d)\n", pid)
|
|
||||||
|
|
||||||
// Query health endpoint if configured
|
|
||||||
if cfg.HealthAddr != "" {
|
|
||||||
healthURL := fmt.Sprintf("http://%s/health", cfg.HealthAddr)
|
|
||||||
resp, err := http.Get(healthURL)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("Health: unreachable (%v)\n", err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode == http.StatusOK {
|
|
||||||
fmt.Println("Health: ok")
|
|
||||||
} else {
|
|
||||||
fmt.Printf("Health: unhealthy (HTTP %d)\n", resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check readiness
|
|
||||||
readyURL := fmt.Sprintf("http://%s/ready", cfg.HealthAddr)
|
|
||||||
resp2, err := http.Get(readyURL)
|
|
||||||
if err == nil {
|
|
||||||
defer resp2.Body.Close()
|
|
||||||
if resp2.StatusCode == http.StatusOK {
|
|
||||||
fmt.Println("Ready: yes")
|
|
||||||
} else {
|
|
||||||
fmt.Println("Ready: no")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// runForeground runs the daemon in the current process (blocking).
|
|
||||||
// This is what `core daemon run` and the detached child process execute.
|
|
||||||
func runForeground(cfg Config) error {
|
|
||||||
os.Setenv("CORE_DAEMON", "1")
|
os.Setenv("CORE_DAEMON", "1")
|
||||||
|
|
||||||
log.Info("Starting daemon",
|
log.Info("Starting daemon",
|
||||||
|
|
@ -301,61 +127,33 @@ func runForeground(cfg Config) error {
|
||||||
return fmt.Errorf("failed to start daemon: %w", err)
|
return fmt.Errorf("failed to start daemon: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create supervisor for managed services
|
|
||||||
sup := process.NewSupervisor(nil) // nil service — we only supervise Go functions
|
|
||||||
|
|
||||||
// Register MCP server as a supervised service
|
|
||||||
sup.RegisterFunc(process.GoSpec{
|
|
||||||
Name: "mcp",
|
|
||||||
Func: func(ctx context.Context) error {
|
|
||||||
return startMCP(ctx, mcpSvc, cfg)
|
|
||||||
},
|
|
||||||
Restart: process.RestartPolicy{
|
|
||||||
Delay: 3 * time.Second,
|
|
||||||
MaxRestarts: -1, // Unlimited restarts
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
// Start supervised services
|
|
||||||
sup.Start()
|
|
||||||
|
|
||||||
// Mark as ready
|
|
||||||
daemon.SetReady(true)
|
|
||||||
|
|
||||||
// Add supervisor status to health checks
|
|
||||||
daemon.AddHealthCheck(func() error {
|
|
||||||
statuses := sup.Statuses()
|
|
||||||
for name, status := range statuses {
|
|
||||||
if !status.Running {
|
|
||||||
return fmt.Errorf("service %s is not running (restarts: %d)", name, status.RestartCount)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
log.Info("Daemon ready",
|
|
||||||
"pid", os.Getpid(),
|
|
||||||
"health", daemon.HealthAddr(),
|
|
||||||
"services", strings.Join(sup.UnitNames(), ", "),
|
|
||||||
)
|
|
||||||
|
|
||||||
// Print supervised service status as JSON for machine consumption
|
|
||||||
statuses := sup.Statuses()
|
|
||||||
if data, err := json.Marshal(statuses); err == nil {
|
|
||||||
log.Debug("Supervised services", "statuses", string(data))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get context that cancels on SIGINT/SIGTERM
|
// Get context that cancels on SIGINT/SIGTERM
|
||||||
ctx := cli.Context()
|
ctx := cli.Context()
|
||||||
|
|
||||||
// Wait for shutdown signal
|
// Start MCP server in background
|
||||||
<-ctx.Done()
|
mcpErrCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
mcpErrCh <- startMCP(ctx, mcpSvc, cfg)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Mark as ready
|
||||||
|
daemon.SetReady(true)
|
||||||
|
log.Info("Daemon ready",
|
||||||
|
"pid", os.Getpid(),
|
||||||
|
"health", daemon.HealthAddr(),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Wait for shutdown signal or MCP error
|
||||||
|
select {
|
||||||
|
case err := <-mcpErrCh:
|
||||||
|
if err != nil && ctx.Err() == nil {
|
||||||
|
log.Error("MCP server error", "err", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
log.Info("Shutting down daemon")
|
log.Info("Shutting down daemon")
|
||||||
|
}
|
||||||
|
|
||||||
// Stop supervised services first
|
|
||||||
sup.Stop()
|
|
||||||
|
|
||||||
// Then stop the daemon (releases PID, stops health server)
|
|
||||||
return daemon.Stop()
|
return daemon.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -378,50 +176,3 @@ func startMCP(ctx context.Context, svc *mcp.Service, cfg Config) error {
|
||||||
return fmt.Errorf("unknown MCP transport: %s (valid: stdio, tcp, socket)", cfg.MCPTransport)
|
return fmt.Errorf("unknown MCP transport: %s (valid: stdio, tcp, socket)", cfg.MCPTransport)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Helpers ---
|
|
||||||
|
|
||||||
// readPID reads the PID file and checks if the process is still running.
|
|
||||||
func readPID(path string) (int, bool) {
|
|
||||||
data, err := os.ReadFile(path)
|
|
||||||
if err != nil {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
pid, err := strconv.Atoi(strings.TrimSpace(string(data)))
|
|
||||||
if err != nil || pid <= 0 {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if process is actually running
|
|
||||||
proc, err := os.FindProcess(pid)
|
|
||||||
if err != nil {
|
|
||||||
return pid, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Signal 0 tests if the process exists without actually sending a signal
|
|
||||||
if err := proc.Signal(syscall.Signal(0)); err != nil {
|
|
||||||
return pid, false
|
|
||||||
}
|
|
||||||
|
|
||||||
return pid, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// waitForHealth polls the health endpoint until it responds or timeout.
|
|
||||||
func waitForHealth(addr string, timeout time.Duration) bool {
|
|
||||||
deadline := time.Now().Add(timeout)
|
|
||||||
url := fmt.Sprintf("http://%s/health", addr)
|
|
||||||
|
|
||||||
for time.Now().Before(deadline) {
|
|
||||||
resp, err := http.Get(url)
|
|
||||||
if err == nil {
|
|
||||||
resp.Body.Close()
|
|
||||||
if resp.StatusCode == http.StatusOK {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
time.Sleep(200 * time.Millisecond)
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go-devops/ansible"
|
"forge.lthn.ai/core/go/pkg/ansible"
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-devops/deploy/coolify"
|
"forge.lthn.ai/core/go/pkg/deploy/coolify"
|
||||||
"forge.lthn.ai/core/go/pkg/i18n"
|
"forge.lthn.ai/core/go/pkg/i18n"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ import (
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
core "forge.lthn.ai/core/go/pkg/framework/core"
|
core "forge.lthn.ai/core/go/pkg/framework/core"
|
||||||
"forge.lthn.ai/core/go-scm/git"
|
"forge.lthn.ai/core/go/pkg/git"
|
||||||
"forge.lthn.ai/core/go/pkg/i18n"
|
"forge.lthn.ai/core/go/pkg/i18n"
|
||||||
"forge.lthn.ai/core/go/pkg/io"
|
"forge.lthn.ai/core/go/pkg/io"
|
||||||
"forge.lthn.ai/core/go/pkg/repos"
|
"forge.lthn.ai/core/go/pkg/repos"
|
||||||
|
|
|
||||||
|
|
@ -3,9 +3,9 @@ package dev
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go-ai/agentic"
|
"forge.lthn.ai/core/go/pkg/agentic"
|
||||||
"forge.lthn.ai/core/go/pkg/framework"
|
"forge.lthn.ai/core/go/pkg/framework"
|
||||||
"forge.lthn.ai/core/go-scm/git"
|
"forge.lthn.ai/core/go/pkg/git"
|
||||||
)
|
)
|
||||||
|
|
||||||
// WorkBundle contains the Core instance for dev work operations.
|
// WorkBundle contains the Core instance for dev work operations.
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ import (
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-scm/git"
|
"forge.lthn.ai/core/go/pkg/git"
|
||||||
"forge.lthn.ai/core/go/pkg/i18n"
|
"forge.lthn.ai/core/go/pkg/i18n"
|
||||||
coreio "forge.lthn.ai/core/go/pkg/io"
|
coreio "forge.lthn.ai/core/go/pkg/io"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-scm/git"
|
"forge.lthn.ai/core/go/pkg/git"
|
||||||
"forge.lthn.ai/core/go/pkg/i18n"
|
"forge.lthn.ai/core/go/pkg/i18n"
|
||||||
coreio "forge.lthn.ai/core/go/pkg/io"
|
coreio "forge.lthn.ai/core/go/pkg/io"
|
||||||
"forge.lthn.ai/core/go/pkg/log"
|
"forge.lthn.ai/core/go/pkg/log"
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-scm/git"
|
"forge.lthn.ai/core/go/pkg/git"
|
||||||
"forge.lthn.ai/core/go/pkg/i18n"
|
"forge.lthn.ai/core/go/pkg/i18n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import (
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-scm/git"
|
"forge.lthn.ai/core/go/pkg/git"
|
||||||
"forge.lthn.ai/core/go/pkg/i18n"
|
"forge.lthn.ai/core/go/pkg/i18n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ import (
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-scm/git"
|
"forge.lthn.ai/core/go/pkg/git"
|
||||||
"forge.lthn.ai/core/go/pkg/i18n"
|
"forge.lthn.ai/core/go/pkg/i18n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-devops/devops"
|
"forge.lthn.ai/core/go/pkg/devops"
|
||||||
"forge.lthn.ai/core/go/pkg/i18n"
|
"forge.lthn.ai/core/go/pkg/i18n"
|
||||||
"forge.lthn.ai/core/go/pkg/io"
|
"forge.lthn.ai/core/go/pkg/io"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -7,9 +7,9 @@ import (
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go-ai/agentic"
|
"forge.lthn.ai/core/go/pkg/agentic"
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-scm/git"
|
"forge.lthn.ai/core/go/pkg/git"
|
||||||
"forge.lthn.ai/core/go/pkg/i18n"
|
"forge.lthn.ai/core/go/pkg/i18n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,10 +5,10 @@ import (
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go-ai/agentic"
|
"forge.lthn.ai/core/go/pkg/agentic"
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go/pkg/framework"
|
"forge.lthn.ai/core/go/pkg/framework"
|
||||||
"forge.lthn.ai/core/go-scm/git"
|
"forge.lthn.ai/core/go/pkg/git"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Tasks for dev service
|
// Tasks for dev service
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
fg "forge.lthn.ai/core/go-scm/forge"
|
fg "forge.lthn.ai/core/go/pkg/forge"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Auth command flags.
|
// Auth command flags.
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
fg "forge.lthn.ai/core/go-scm/forge"
|
fg "forge.lthn.ai/core/go/pkg/forge"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Config command flags.
|
// Config command flags.
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import (
|
||||||
forgejo "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2"
|
forgejo "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
fg "forge.lthn.ai/core/go-scm/forge"
|
fg "forge.lthn.ai/core/go/pkg/forge"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Issues command flags.
|
// Issues command flags.
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ import (
|
||||||
forgejo "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2"
|
forgejo "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
fg "forge.lthn.ai/core/go-scm/forge"
|
fg "forge.lthn.ai/core/go/pkg/forge"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Labels command flags.
|
// Labels command flags.
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ import (
|
||||||
forgejo "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2"
|
forgejo "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
fg "forge.lthn.ai/core/go-scm/forge"
|
fg "forge.lthn.ai/core/go/pkg/forge"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Migrate command flags.
|
// Migrate command flags.
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
fg "forge.lthn.ai/core/go-scm/forge"
|
fg "forge.lthn.ai/core/go/pkg/forge"
|
||||||
)
|
)
|
||||||
|
|
||||||
// addOrgsCommand adds the 'orgs' subcommand for listing organisations.
|
// addOrgsCommand adds the 'orgs' subcommand for listing organisations.
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import (
|
||||||
forgejo "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2"
|
forgejo "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
fg "forge.lthn.ai/core/go-scm/forge"
|
fg "forge.lthn.ai/core/go/pkg/forge"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PRs command flags.
|
// PRs command flags.
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ import (
|
||||||
forgejo "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2"
|
forgejo "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
fg "forge.lthn.ai/core/go-scm/forge"
|
fg "forge.lthn.ai/core/go/pkg/forge"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Repos command flags.
|
// Repos command flags.
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
fg "forge.lthn.ai/core/go-scm/forge"
|
fg "forge.lthn.ai/core/go/pkg/forge"
|
||||||
)
|
)
|
||||||
|
|
||||||
// addStatusCommand adds the 'status' subcommand for instance info.
|
// addStatusCommand adds the 'status' subcommand for instance info.
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ import (
|
||||||
forgejo "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2"
|
forgejo "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
fg "forge.lthn.ai/core/go-scm/forge"
|
fg "forge.lthn.ai/core/go/pkg/forge"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Sync command flags.
|
// Sync command flags.
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
gt "forge.lthn.ai/core/go-scm/gitea"
|
gt "forge.lthn.ai/core/go/pkg/gitea"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Config command flags.
|
// Config command flags.
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import (
|
||||||
"code.gitea.io/sdk/gitea"
|
"code.gitea.io/sdk/gitea"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
gt "forge.lthn.ai/core/go-scm/gitea"
|
gt "forge.lthn.ai/core/go/pkg/gitea"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Issues command flags.
|
// Issues command flags.
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
gt "forge.lthn.ai/core/go-scm/gitea"
|
gt "forge.lthn.ai/core/go/pkg/gitea"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Mirror command flags.
|
// Mirror command flags.
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import (
|
||||||
sdk "code.gitea.io/sdk/gitea"
|
sdk "code.gitea.io/sdk/gitea"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
gt "forge.lthn.ai/core/go-scm/gitea"
|
gt "forge.lthn.ai/core/go/pkg/gitea"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PRs command flags.
|
// PRs command flags.
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
gt "forge.lthn.ai/core/go-scm/gitea"
|
gt "forge.lthn.ai/core/go/pkg/gitea"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Repos command flags.
|
// Repos command flags.
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ import (
|
||||||
"code.gitea.io/sdk/gitea"
|
"code.gitea.io/sdk/gitea"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
gt "forge.lthn.ai/core/go-scm/gitea"
|
gt "forge.lthn.ai/core/go/pkg/gitea"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Sync command flags.
|
// Sync command flags.
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ import (
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-ai/mcp"
|
"forge.lthn.ai/core/go/pkg/mcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|
|
||||||
832
cmd/ml/chat.js
832
cmd/ml/chat.js
|
|
@ -1,832 +0,0 @@
|
||||||
// src/styles.ts
|
|
||||||
var chatStyles = `
|
|
||||||
:host {
|
|
||||||
display: flex;
|
|
||||||
flex-direction: column;
|
|
||||||
background: var(--lem-bg, #1a1a1e);
|
|
||||||
color: var(--lem-text, #e0e0e0);
|
|
||||||
font-family: var(--lem-font, system-ui, -apple-system, sans-serif);
|
|
||||||
font-size: 14px;
|
|
||||||
line-height: 1.5;
|
|
||||||
border-radius: 12px;
|
|
||||||
overflow: hidden;
|
|
||||||
border: 1px solid rgba(255, 255, 255, 0.08);
|
|
||||||
}
|
|
||||||
|
|
||||||
.header {
|
|
||||||
display: flex;
|
|
||||||
align-items: center;
|
|
||||||
gap: 10px;
|
|
||||||
padding: 14px 18px;
|
|
||||||
background: rgba(255, 255, 255, 0.03);
|
|
||||||
border-bottom: 1px solid rgba(255, 255, 255, 0.06);
|
|
||||||
flex-shrink: 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
.header-icon {
|
|
||||||
width: 28px;
|
|
||||||
height: 28px;
|
|
||||||
border-radius: 8px;
|
|
||||||
background: var(--lem-accent, #5865f2);
|
|
||||||
display: flex;
|
|
||||||
align-items: center;
|
|
||||||
justify-content: center;
|
|
||||||
font-size: 14px;
|
|
||||||
font-weight: 700;
|
|
||||||
color: #fff;
|
|
||||||
}
|
|
||||||
|
|
||||||
.header-title {
|
|
||||||
font-size: 15px;
|
|
||||||
font-weight: 600;
|
|
||||||
color: var(--lem-text, #e0e0e0);
|
|
||||||
}
|
|
||||||
|
|
||||||
.header-model {
|
|
||||||
font-size: 11px;
|
|
||||||
color: rgba(255, 255, 255, 0.35);
|
|
||||||
margin-left: auto;
|
|
||||||
font-family: ui-monospace, SFMono-Regular, Menlo, monospace;
|
|
||||||
}
|
|
||||||
|
|
||||||
.header-status {
|
|
||||||
width: 8px;
|
|
||||||
height: 8px;
|
|
||||||
border-radius: 50%;
|
|
||||||
background: #43b581;
|
|
||||||
flex-shrink: 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
.header-status.disconnected {
|
|
||||||
background: #f04747;
|
|
||||||
}
|
|
||||||
`;
|
|
||||||
var messagesStyles = `
|
|
||||||
:host {
|
|
||||||
display: block;
|
|
||||||
flex: 1;
|
|
||||||
overflow-y: auto;
|
|
||||||
overflow-x: hidden;
|
|
||||||
padding: 16px 0;
|
|
||||||
scroll-behavior: smooth;
|
|
||||||
}
|
|
||||||
|
|
||||||
:host::-webkit-scrollbar {
|
|
||||||
width: 6px;
|
|
||||||
}
|
|
||||||
|
|
||||||
:host::-webkit-scrollbar-track {
|
|
||||||
background: transparent;
|
|
||||||
}
|
|
||||||
|
|
||||||
:host::-webkit-scrollbar-thumb {
|
|
||||||
background: rgba(255, 255, 255, 0.12);
|
|
||||||
border-radius: 3px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.empty {
|
|
||||||
display: flex;
|
|
||||||
flex-direction: column;
|
|
||||||
align-items: center;
|
|
||||||
justify-content: center;
|
|
||||||
height: 100%;
|
|
||||||
gap: 12px;
|
|
||||||
color: rgba(255, 255, 255, 0.25);
|
|
||||||
}
|
|
||||||
|
|
||||||
.empty-icon {
|
|
||||||
font-size: 36px;
|
|
||||||
opacity: 0.4;
|
|
||||||
}
|
|
||||||
|
|
||||||
.empty-text {
|
|
||||||
font-size: 14px;
|
|
||||||
}
|
|
||||||
`;
|
|
||||||
var messageStyles = `
|
|
||||||
:host {
|
|
||||||
display: block;
|
|
||||||
padding: 6px 18px;
|
|
||||||
}
|
|
||||||
|
|
||||||
:host([role="user"]) .bubble {
|
|
||||||
background: var(--lem-msg-user, #2a2a3e);
|
|
||||||
margin-left: 40px;
|
|
||||||
border-radius: 12px 12px 4px 12px;
|
|
||||||
}
|
|
||||||
|
|
||||||
:host([role="assistant"]) .bubble {
|
|
||||||
background: var(--lem-msg-assistant, #1e1e2a);
|
|
||||||
margin-right: 40px;
|
|
||||||
border-radius: 12px 12px 12px 4px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.bubble {
|
|
||||||
padding: 10px 14px;
|
|
||||||
word-wrap: break-word;
|
|
||||||
overflow-wrap: break-word;
|
|
||||||
}
|
|
||||||
|
|
||||||
.role {
|
|
||||||
font-size: 11px;
|
|
||||||
font-weight: 600;
|
|
||||||
text-transform: uppercase;
|
|
||||||
letter-spacing: 0.5px;
|
|
||||||
margin-bottom: 4px;
|
|
||||||
color: rgba(255, 255, 255, 0.35);
|
|
||||||
}
|
|
||||||
|
|
||||||
:host([role="assistant"]) .role {
|
|
||||||
color: var(--lem-accent, #5865f2);
|
|
||||||
}
|
|
||||||
|
|
||||||
.content {
|
|
||||||
color: var(--lem-text, #e0e0e0);
|
|
||||||
line-height: 1.6;
|
|
||||||
}
|
|
||||||
|
|
||||||
.content p {
|
|
||||||
margin: 0 0 8px 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
.content p:last-child {
|
|
||||||
margin-bottom: 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
.content strong {
|
|
||||||
font-weight: 600;
|
|
||||||
color: #fff;
|
|
||||||
}
|
|
||||||
|
|
||||||
.content em {
|
|
||||||
font-style: italic;
|
|
||||||
color: rgba(255, 255, 255, 0.8);
|
|
||||||
}
|
|
||||||
|
|
||||||
.content code {
|
|
||||||
font-family: ui-monospace, SFMono-Regular, Menlo, monospace;
|
|
||||||
font-size: 12px;
|
|
||||||
background: rgba(0, 0, 0, 0.3);
|
|
||||||
padding: 2px 5px;
|
|
||||||
border-radius: 4px;
|
|
||||||
color: #e8a0bf;
|
|
||||||
}
|
|
||||||
|
|
||||||
.content pre {
|
|
||||||
margin: 8px 0;
|
|
||||||
padding: 12px;
|
|
||||||
background: rgba(0, 0, 0, 0.35);
|
|
||||||
border-radius: 8px;
|
|
||||||
overflow-x: auto;
|
|
||||||
border: 1px solid rgba(255, 255, 255, 0.06);
|
|
||||||
}
|
|
||||||
|
|
||||||
.content pre code {
|
|
||||||
background: none;
|
|
||||||
padding: 0;
|
|
||||||
font-size: 12px;
|
|
||||||
color: #c9d1d9;
|
|
||||||
line-height: 1.5;
|
|
||||||
}
|
|
||||||
|
|
||||||
.think-panel {
|
|
||||||
margin: 6px 0 8px;
|
|
||||||
padding: 8px 12px;
|
|
||||||
background: rgba(88, 101, 242, 0.06);
|
|
||||||
border-left: 2px solid rgba(88, 101, 242, 0.3);
|
|
||||||
border-radius: 0 6px 6px 0;
|
|
||||||
font-size: 12px;
|
|
||||||
color: rgba(255, 255, 255, 0.45);
|
|
||||||
line-height: 1.5;
|
|
||||||
max-height: 200px;
|
|
||||||
overflow-y: auto;
|
|
||||||
}
|
|
||||||
|
|
||||||
.think-panel::-webkit-scrollbar {
|
|
||||||
width: 4px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.think-panel::-webkit-scrollbar-thumb {
|
|
||||||
background: rgba(255, 255, 255, 0.1);
|
|
||||||
border-radius: 2px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.think-label {
|
|
||||||
font-size: 10px;
|
|
||||||
font-weight: 600;
|
|
||||||
text-transform: uppercase;
|
|
||||||
letter-spacing: 0.5px;
|
|
||||||
color: rgba(88, 101, 242, 0.5);
|
|
||||||
margin-bottom: 4px;
|
|
||||||
cursor: pointer;
|
|
||||||
user-select: none;
|
|
||||||
}
|
|
||||||
|
|
||||||
.think-label:hover {
|
|
||||||
color: rgba(88, 101, 242, 0.7);
|
|
||||||
}
|
|
||||||
|
|
||||||
.think-panel.collapsed .think-content {
|
|
||||||
display: none;
|
|
||||||
}
|
|
||||||
|
|
||||||
.cursor {
|
|
||||||
display: inline-block;
|
|
||||||
width: 7px;
|
|
||||||
height: 16px;
|
|
||||||
background: var(--lem-accent, #5865f2);
|
|
||||||
border-radius: 1px;
|
|
||||||
animation: blink 0.8s step-end infinite;
|
|
||||||
vertical-align: text-bottom;
|
|
||||||
margin-left: 2px;
|
|
||||||
}
|
|
||||||
|
|
||||||
@keyframes blink {
|
|
||||||
50% { opacity: 0; }
|
|
||||||
}
|
|
||||||
`;
|
|
||||||
var inputStyles = `
|
|
||||||
:host {
|
|
||||||
display: block;
|
|
||||||
padding: 12px 16px 16px;
|
|
||||||
border-top: 1px solid rgba(255, 255, 255, 0.06);
|
|
||||||
flex-shrink: 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
.input-wrapper {
|
|
||||||
display: flex;
|
|
||||||
align-items: flex-end;
|
|
||||||
gap: 10px;
|
|
||||||
background: rgba(255, 255, 255, 0.05);
|
|
||||||
border: 1px solid rgba(255, 255, 255, 0.08);
|
|
||||||
border-radius: 12px;
|
|
||||||
padding: 8px 12px;
|
|
||||||
transition: border-color 0.15s;
|
|
||||||
}
|
|
||||||
|
|
||||||
.input-wrapper:focus-within {
|
|
||||||
border-color: var(--lem-accent, #5865f2);
|
|
||||||
}
|
|
||||||
|
|
||||||
textarea {
|
|
||||||
flex: 1;
|
|
||||||
background: none;
|
|
||||||
border: none;
|
|
||||||
outline: none;
|
|
||||||
color: var(--lem-text, #e0e0e0);
|
|
||||||
font-family: inherit;
|
|
||||||
font-size: 14px;
|
|
||||||
line-height: 1.5;
|
|
||||||
resize: none;
|
|
||||||
max-height: 120px;
|
|
||||||
min-height: 22px;
|
|
||||||
padding: 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
textarea::placeholder {
|
|
||||||
color: rgba(255, 255, 255, 0.25);
|
|
||||||
}
|
|
||||||
|
|
||||||
.send-btn {
|
|
||||||
background: var(--lem-accent, #5865f2);
|
|
||||||
border: none;
|
|
||||||
border-radius: 8px;
|
|
||||||
color: #fff;
|
|
||||||
width: 32px;
|
|
||||||
height: 32px;
|
|
||||||
cursor: pointer;
|
|
||||||
display: flex;
|
|
||||||
align-items: center;
|
|
||||||
justify-content: center;
|
|
||||||
flex-shrink: 0;
|
|
||||||
transition: opacity 0.15s, transform 0.1s;
|
|
||||||
}
|
|
||||||
|
|
||||||
.send-btn:hover {
|
|
||||||
opacity: 0.85;
|
|
||||||
}
|
|
||||||
|
|
||||||
.send-btn:active {
|
|
||||||
transform: scale(0.95);
|
|
||||||
}
|
|
||||||
|
|
||||||
.send-btn:disabled {
|
|
||||||
opacity: 0.3;
|
|
||||||
cursor: default;
|
|
||||||
transform: none;
|
|
||||||
}
|
|
||||||
|
|
||||||
.send-btn svg {
|
|
||||||
width: 16px;
|
|
||||||
height: 16px;
|
|
||||||
}
|
|
||||||
`;
|
|
||||||
|
|
||||||
// src/lem-messages.ts
|
|
||||||
var LemMessages = class extends HTMLElement {
|
|
||||||
shadow;
|
|
||||||
container;
|
|
||||||
emptyEl;
|
|
||||||
shouldAutoScroll = true;
|
|
||||||
constructor() {
|
|
||||||
super();
|
|
||||||
this.shadow = this.attachShadow({ mode: "open" });
|
|
||||||
}
|
|
||||||
connectedCallback() {
|
|
||||||
const style = document.createElement("style");
|
|
||||||
style.textContent = messagesStyles;
|
|
||||||
this.container = document.createElement("div");
|
|
||||||
this.emptyEl = document.createElement("div");
|
|
||||||
this.emptyEl.className = "empty";
|
|
||||||
const emptyIcon = document.createElement("div");
|
|
||||||
emptyIcon.className = "empty-icon";
|
|
||||||
emptyIcon.textContent = "\u2728";
|
|
||||||
const emptyText = document.createElement("div");
|
|
||||||
emptyText.className = "empty-text";
|
|
||||||
emptyText.textContent = "Start a conversation";
|
|
||||||
this.emptyEl.appendChild(emptyIcon);
|
|
||||||
this.emptyEl.appendChild(emptyText);
|
|
||||||
this.shadow.appendChild(style);
|
|
||||||
this.shadow.appendChild(this.emptyEl);
|
|
||||||
this.shadow.appendChild(this.container);
|
|
||||||
this.addEventListener("scroll", () => {
|
|
||||||
const threshold = 60;
|
|
||||||
this.shouldAutoScroll = this.scrollHeight - this.scrollTop - this.clientHeight < threshold;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
addMessage(role, text) {
|
|
||||||
this.emptyEl.style.display = "none";
|
|
||||||
const msg = document.createElement("lem-message");
|
|
||||||
msg.setAttribute("role", role);
|
|
||||||
this.container.appendChild(msg);
|
|
||||||
if (text) {
|
|
||||||
msg.text = text;
|
|
||||||
}
|
|
||||||
this.scrollToBottom();
|
|
||||||
return msg;
|
|
||||||
}
|
|
||||||
scrollToBottom() {
|
|
||||||
if (this.shouldAutoScroll) {
|
|
||||||
requestAnimationFrame(() => {
|
|
||||||
this.scrollTop = this.scrollHeight;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
clear() {
|
|
||||||
this.container.replaceChildren();
|
|
||||||
this.emptyEl.style.display = "";
|
|
||||||
this.shouldAutoScroll = true;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
customElements.define("lem-messages", LemMessages);
|
|
||||||
|
|
||||||
// src/markdown.ts
|
|
||||||
function escapeHtml(text) {
|
|
||||||
return text.replace(/&/g, "&").replace(/</g, "<").replace(/>/g, ">").replace(/"/g, """);
|
|
||||||
}
|
|
||||||
function parseInline(text) {
|
|
||||||
let result = escapeHtml(text);
|
|
||||||
result = result.replace(/`([^`]+)`/g, "<code>$1</code>");
|
|
||||||
result = result.replace(/\*\*(.+?)\*\*/g, "<strong>$1</strong>");
|
|
||||||
result = result.replace(/__(.+?)__/g, "<strong>$1</strong>");
|
|
||||||
result = result.replace(/(?<!\w)\*([^*]+)\*(?!\w)/g, "<em>$1</em>");
|
|
||||||
result = result.replace(/(?<!\w)_([^_]+)_(?!\w)/g, "<em>$1</em>");
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
function renderMarkdown(text) {
|
|
||||||
const lines = text.split("\n");
|
|
||||||
const output = [];
|
|
||||||
let inCodeBlock = false;
|
|
||||||
let codeLines = [];
|
|
||||||
let codeLang = "";
|
|
||||||
for (const line of lines) {
|
|
||||||
if (line.trimStart().startsWith("```")) {
|
|
||||||
if (!inCodeBlock) {
|
|
||||||
inCodeBlock = true;
|
|
||||||
codeLang = line.trimStart().slice(3).trim();
|
|
||||||
codeLines = [];
|
|
||||||
} else {
|
|
||||||
const langAttr = codeLang ? ` data-lang="${escapeHtml(codeLang)}"` : "";
|
|
||||||
output.push(
|
|
||||||
`<pre${langAttr}><code>${escapeHtml(codeLines.join("\n"))}</code></pre>`
|
|
||||||
);
|
|
||||||
inCodeBlock = false;
|
|
||||||
codeLines = [];
|
|
||||||
codeLang = "";
|
|
||||||
}
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (inCodeBlock) {
|
|
||||||
codeLines.push(line);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (line.trim() === "") {
|
|
||||||
output.push("");
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
output.push(parseInline(line));
|
|
||||||
}
|
|
||||||
if (inCodeBlock) {
|
|
||||||
const langAttr = codeLang ? ` data-lang="${escapeHtml(codeLang)}"` : "";
|
|
||||||
output.push(
|
|
||||||
`<pre${langAttr}><code>${escapeHtml(codeLines.join("\n"))}</code></pre>`
|
|
||||||
);
|
|
||||||
}
|
|
||||||
const paragraphs = [];
|
|
||||||
let current = [];
|
|
||||||
for (const line of output) {
|
|
||||||
if (line === "") {
|
|
||||||
if (current.length > 0) {
|
|
||||||
paragraphs.push(wrapParagraph(current));
|
|
||||||
current = [];
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
current.push(line);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (current.length > 0) {
|
|
||||||
paragraphs.push(wrapParagraph(current));
|
|
||||||
}
|
|
||||||
return paragraphs.join("");
|
|
||||||
}
|
|
||||||
function wrapParagraph(lines) {
|
|
||||||
const joined = lines.join("<br>");
|
|
||||||
if (joined.startsWith("<pre")) return joined;
|
|
||||||
return `<p>${joined}</p>`;
|
|
||||||
}
|
|
||||||
|
|
||||||
// src/lem-message.ts
|
|
||||||
var LemMessage = class extends HTMLElement {
|
|
||||||
shadow;
|
|
||||||
thinkPanel;
|
|
||||||
thinkContent;
|
|
||||||
thinkLabel;
|
|
||||||
contentEl;
|
|
||||||
cursorEl;
|
|
||||||
_text = "";
|
|
||||||
_streaming = false;
|
|
||||||
_thinkCollapsed = false;
|
|
||||||
constructor() {
|
|
||||||
super();
|
|
||||||
this.shadow = this.attachShadow({ mode: "open" });
|
|
||||||
}
|
|
||||||
connectedCallback() {
|
|
||||||
const role = this.getAttribute("role") || "user";
|
|
||||||
const style = document.createElement("style");
|
|
||||||
style.textContent = messageStyles;
|
|
||||||
const bubble = document.createElement("div");
|
|
||||||
bubble.className = "bubble";
|
|
||||||
const roleEl = document.createElement("div");
|
|
||||||
roleEl.className = "role";
|
|
||||||
roleEl.textContent = role === "assistant" ? "LEM" : "You";
|
|
||||||
this.thinkPanel = document.createElement("div");
|
|
||||||
this.thinkPanel.className = "think-panel";
|
|
||||||
this.thinkPanel.style.display = "none";
|
|
||||||
this.thinkLabel = document.createElement("div");
|
|
||||||
this.thinkLabel.className = "think-label";
|
|
||||||
this.thinkLabel.textContent = "\u25BC reasoning";
|
|
||||||
this.thinkLabel.addEventListener("click", () => {
|
|
||||||
this._thinkCollapsed = !this._thinkCollapsed;
|
|
||||||
this.thinkPanel.classList.toggle("collapsed", this._thinkCollapsed);
|
|
||||||
this.thinkLabel.textContent = this._thinkCollapsed ? "\u25B6 reasoning" : "\u25BC reasoning";
|
|
||||||
});
|
|
||||||
this.thinkContent = document.createElement("div");
|
|
||||||
this.thinkContent.className = "think-content";
|
|
||||||
this.thinkPanel.appendChild(this.thinkLabel);
|
|
||||||
this.thinkPanel.appendChild(this.thinkContent);
|
|
||||||
this.contentEl = document.createElement("div");
|
|
||||||
this.contentEl.className = "content";
|
|
||||||
bubble.appendChild(roleEl);
|
|
||||||
if (role === "assistant") {
|
|
||||||
bubble.appendChild(this.thinkPanel);
|
|
||||||
}
|
|
||||||
bubble.appendChild(this.contentEl);
|
|
||||||
this.shadow.appendChild(style);
|
|
||||||
this.shadow.appendChild(bubble);
|
|
||||||
if (this._text) {
|
|
||||||
this.render();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
get text() {
|
|
||||||
return this._text;
|
|
||||||
}
|
|
||||||
set text(value) {
|
|
||||||
this._text = value;
|
|
||||||
this.render();
|
|
||||||
}
|
|
||||||
get streaming() {
|
|
||||||
return this._streaming;
|
|
||||||
}
|
|
||||||
set streaming(value) {
|
|
||||||
this._streaming = value;
|
|
||||||
this.render();
|
|
||||||
}
|
|
||||||
appendToken(token) {
|
|
||||||
this._text += token;
|
|
||||||
this.render();
|
|
||||||
}
|
|
||||||
/**
|
|
||||||
* Splits text into think/response portions and renders each.
|
|
||||||
*
|
|
||||||
* Safety: renderMarkdown() escapes all HTML entities (& < > ") before any
|
|
||||||
* inline formatting is applied. The source is the local MLX model output,
|
|
||||||
* not arbitrary user HTML. Shadow DOM provides additional isolation.
|
|
||||||
*/
|
|
||||||
render() {
|
|
||||||
if (!this.contentEl) return;
|
|
||||||
const { think, response } = this.splitThink(this._text);
|
|
||||||
if (think !== null && this.thinkPanel) {
|
|
||||||
this.thinkPanel.style.display = "";
|
|
||||||
this.thinkContent.textContent = think;
|
|
||||||
}
|
|
||||||
const responseHtml = renderMarkdown(response);
|
|
||||||
this.contentEl.innerHTML = responseHtml;
|
|
||||||
if (this._streaming) {
|
|
||||||
if (!this.cursorEl) {
|
|
||||||
this.cursorEl = document.createElement("span");
|
|
||||||
this.cursorEl.className = "cursor";
|
|
||||||
}
|
|
||||||
if (think !== null && !this._text.includes("</think>")) {
|
|
||||||
this.thinkContent.appendChild(this.cursorEl);
|
|
||||||
} else {
|
|
||||||
const lastChild = this.contentEl.lastElementChild || this.contentEl;
|
|
||||||
lastChild.appendChild(this.cursorEl);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
/**
|
|
||||||
* Split raw text into think content and response content.
|
|
||||||
* Returns { think: string | null, response: string }
|
|
||||||
*/
|
|
||||||
splitThink(text) {
|
|
||||||
const thinkStart = text.indexOf("<think>");
|
|
||||||
if (thinkStart === -1) {
|
|
||||||
return { think: null, response: text };
|
|
||||||
}
|
|
||||||
const afterOpen = thinkStart + "<think>".length;
|
|
||||||
const thinkEnd = text.indexOf("</think>", afterOpen);
|
|
||||||
if (thinkEnd === -1) {
|
|
||||||
return {
|
|
||||||
think: text.slice(afterOpen).trim(),
|
|
||||||
response: text.slice(0, thinkStart).trim()
|
|
||||||
};
|
|
||||||
}
|
|
||||||
const thinkText = text.slice(afterOpen, thinkEnd).trim();
|
|
||||||
const beforeThink = text.slice(0, thinkStart).trim();
|
|
||||||
const afterThink = text.slice(thinkEnd + "</think>".length).trim();
|
|
||||||
const response = [beforeThink, afterThink].filter(Boolean).join("\n");
|
|
||||||
return { think: thinkText, response };
|
|
||||||
}
|
|
||||||
};
|
|
||||||
customElements.define("lem-message", LemMessage);
|
|
||||||
|
|
||||||
// src/lem-input.ts
|
|
||||||
var LemInput = class extends HTMLElement {
|
|
||||||
shadow;
|
|
||||||
textarea;
|
|
||||||
sendBtn;
|
|
||||||
_disabled = false;
|
|
||||||
constructor() {
|
|
||||||
super();
|
|
||||||
this.shadow = this.attachShadow({ mode: "open" });
|
|
||||||
}
|
|
||||||
connectedCallback() {
|
|
||||||
const style = document.createElement("style");
|
|
||||||
style.textContent = inputStyles;
|
|
||||||
const wrapper = document.createElement("div");
|
|
||||||
wrapper.className = "input-wrapper";
|
|
||||||
this.textarea = document.createElement("textarea");
|
|
||||||
this.textarea.rows = 1;
|
|
||||||
this.textarea.placeholder = "Message LEM...";
|
|
||||||
this.sendBtn = document.createElement("button");
|
|
||||||
this.sendBtn.className = "send-btn";
|
|
||||||
this.sendBtn.type = "button";
|
|
||||||
this.sendBtn.disabled = true;
|
|
||||||
this.sendBtn.appendChild(this.createSendIcon());
|
|
||||||
wrapper.appendChild(this.textarea);
|
|
||||||
wrapper.appendChild(this.sendBtn);
|
|
||||||
this.shadow.appendChild(style);
|
|
||||||
this.shadow.appendChild(wrapper);
|
|
||||||
this.textarea.addEventListener("input", () => {
|
|
||||||
this.textarea.style.height = "auto";
|
|
||||||
this.textarea.style.height = Math.min(this.textarea.scrollHeight, 120) + "px";
|
|
||||||
this.sendBtn.disabled = this._disabled || this.textarea.value.trim() === "";
|
|
||||||
});
|
|
||||||
this.textarea.addEventListener("keydown", (e) => {
|
|
||||||
if (e.key === "Enter" && !e.shiftKey) {
|
|
||||||
e.preventDefault();
|
|
||||||
this.submit();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
this.sendBtn.addEventListener("click", () => this.submit());
|
|
||||||
}
|
|
||||||
/** Build the send arrow SVG using DOM API (no innerHTML) */
|
|
||||||
createSendIcon() {
|
|
||||||
const ns = "http://www.w3.org/2000/svg";
|
|
||||||
const svg = document.createElementNS(ns, "svg");
|
|
||||||
svg.setAttribute("viewBox", "0 0 24 24");
|
|
||||||
svg.setAttribute("fill", "none");
|
|
||||||
svg.setAttribute("stroke", "currentColor");
|
|
||||||
svg.setAttribute("stroke-width", "2");
|
|
||||||
svg.setAttribute("stroke-linecap", "round");
|
|
||||||
svg.setAttribute("stroke-linejoin", "round");
|
|
||||||
svg.setAttribute("width", "16");
|
|
||||||
svg.setAttribute("height", "16");
|
|
||||||
const line = document.createElementNS(ns, "line");
|
|
||||||
line.setAttribute("x1", "22");
|
|
||||||
line.setAttribute("y1", "2");
|
|
||||||
line.setAttribute("x2", "11");
|
|
||||||
line.setAttribute("y2", "13");
|
|
||||||
const polygon = document.createElementNS(ns, "polygon");
|
|
||||||
polygon.setAttribute("points", "22 2 15 22 11 13 2 9 22 2");
|
|
||||||
svg.appendChild(line);
|
|
||||||
svg.appendChild(polygon);
|
|
||||||
return svg;
|
|
||||||
}
|
|
||||||
submit() {
|
|
||||||
const text = this.textarea.value.trim();
|
|
||||||
if (!text || this._disabled) return;
|
|
||||||
this.dispatchEvent(
|
|
||||||
new CustomEvent("lem-send", {
|
|
||||||
bubbles: true,
|
|
||||||
composed: true,
|
|
||||||
detail: { text }
|
|
||||||
})
|
|
||||||
);
|
|
||||||
this.textarea.value = "";
|
|
||||||
this.textarea.style.height = "auto";
|
|
||||||
this.sendBtn.disabled = true;
|
|
||||||
this.textarea.focus();
|
|
||||||
}
|
|
||||||
get disabled() {
|
|
||||||
return this._disabled;
|
|
||||||
}
|
|
||||||
set disabled(value) {
|
|
||||||
this._disabled = value;
|
|
||||||
this.textarea.disabled = value;
|
|
||||||
this.sendBtn.disabled = value || this.textarea.value.trim() === "";
|
|
||||||
this.textarea.placeholder = value ? "LEM is thinking..." : "Message LEM...";
|
|
||||||
}
|
|
||||||
focus() {
|
|
||||||
this.textarea?.focus();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
customElements.define("lem-input", LemInput);
|
|
||||||
|
|
||||||
// src/lem-chat.ts
|
|
||||||
var LemChat = class extends HTMLElement {
|
|
||||||
shadow;
|
|
||||||
messages;
|
|
||||||
input;
|
|
||||||
statusEl;
|
|
||||||
history = [];
|
|
||||||
abortController = null;
|
|
||||||
static get observedAttributes() {
|
|
||||||
return ["endpoint", "model", "system-prompt", "max-tokens", "temperature"];
|
|
||||||
}
|
|
||||||
constructor() {
|
|
||||||
super();
|
|
||||||
this.shadow = this.attachShadow({ mode: "open" });
|
|
||||||
}
|
|
||||||
connectedCallback() {
|
|
||||||
const style = document.createElement("style");
|
|
||||||
style.textContent = chatStyles;
|
|
||||||
const header = document.createElement("div");
|
|
||||||
header.className = "header";
|
|
||||||
this.statusEl = document.createElement("div");
|
|
||||||
this.statusEl.className = "header-status";
|
|
||||||
const icon = document.createElement("div");
|
|
||||||
icon.className = "header-icon";
|
|
||||||
icon.textContent = "L";
|
|
||||||
const title = document.createElement("div");
|
|
||||||
title.className = "header-title";
|
|
||||||
title.textContent = "LEM";
|
|
||||||
const modelLabel = document.createElement("div");
|
|
||||||
modelLabel.className = "header-model";
|
|
||||||
modelLabel.textContent = this.getAttribute("model") || "local";
|
|
||||||
header.appendChild(this.statusEl);
|
|
||||||
header.appendChild(icon);
|
|
||||||
header.appendChild(title);
|
|
||||||
header.appendChild(modelLabel);
|
|
||||||
this.messages = document.createElement("lem-messages");
|
|
||||||
this.input = document.createElement("lem-input");
|
|
||||||
this.shadow.appendChild(style);
|
|
||||||
this.shadow.appendChild(header);
|
|
||||||
this.shadow.appendChild(this.messages);
|
|
||||||
this.shadow.appendChild(this.input);
|
|
||||||
this.addEventListener("lem-send", ((e) => {
|
|
||||||
this.handleSend(e.detail.text);
|
|
||||||
}));
|
|
||||||
const systemPrompt = this.getAttribute("system-prompt");
|
|
||||||
if (systemPrompt) {
|
|
||||||
this.history.push({ role: "system", content: systemPrompt });
|
|
||||||
}
|
|
||||||
this.checkConnection();
|
|
||||||
requestAnimationFrame(() => this.input.focus());
|
|
||||||
}
|
|
||||||
disconnectedCallback() {
|
|
||||||
this.abortController?.abort();
|
|
||||||
}
|
|
||||||
get endpoint() {
|
|
||||||
const attr = this.getAttribute("endpoint");
|
|
||||||
if (!attr) return window.location.origin;
|
|
||||||
return attr;
|
|
||||||
}
|
|
||||||
get model() {
|
|
||||||
return this.getAttribute("model") || "";
|
|
||||||
}
|
|
||||||
get maxTokens() {
|
|
||||||
const val = this.getAttribute("max-tokens");
|
|
||||||
return val ? parseInt(val, 10) : 2048;
|
|
||||||
}
|
|
||||||
get temperature() {
|
|
||||||
const val = this.getAttribute("temperature");
|
|
||||||
return val ? parseFloat(val) : 0.7;
|
|
||||||
}
|
|
||||||
async checkConnection() {
|
|
||||||
try {
|
|
||||||
const resp = await fetch(`${this.endpoint}/v1/models`, {
|
|
||||||
signal: AbortSignal.timeout(3e3)
|
|
||||||
});
|
|
||||||
this.statusEl.classList.toggle("disconnected", !resp.ok);
|
|
||||||
} catch {
|
|
||||||
this.statusEl.classList.add("disconnected");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
async handleSend(text) {
|
|
||||||
this.messages.addMessage("user", text);
|
|
||||||
this.history.push({ role: "user", content: text });
|
|
||||||
const assistantMsg = this.messages.addMessage("assistant");
|
|
||||||
assistantMsg.streaming = true;
|
|
||||||
this.input.disabled = true;
|
|
||||||
this.abortController?.abort();
|
|
||||||
this.abortController = new AbortController();
|
|
||||||
let fullResponse = "";
|
|
||||||
try {
|
|
||||||
const response = await fetch(`${this.endpoint}/v1/chat/completions`, {
|
|
||||||
method: "POST",
|
|
||||||
headers: { "Content-Type": "application/json" },
|
|
||||||
signal: this.abortController.signal,
|
|
||||||
body: JSON.stringify({
|
|
||||||
model: this.model,
|
|
||||||
messages: this.history,
|
|
||||||
max_tokens: this.maxTokens,
|
|
||||||
temperature: this.temperature,
|
|
||||||
stream: true
|
|
||||||
})
|
|
||||||
});
|
|
||||||
if (!response.ok) {
|
|
||||||
throw new Error(`Server error: ${response.status}`);
|
|
||||||
}
|
|
||||||
if (!response.body) {
|
|
||||||
throw new Error("No response body");
|
|
||||||
}
|
|
||||||
const reader = response.body.getReader();
|
|
||||||
const decoder = new TextDecoder();
|
|
||||||
let buffer = "";
|
|
||||||
while (true) {
|
|
||||||
const { done, value } = await reader.read();
|
|
||||||
if (done) break;
|
|
||||||
buffer += decoder.decode(value, { stream: true });
|
|
||||||
const lines = buffer.split("\n");
|
|
||||||
buffer = lines.pop() || "";
|
|
||||||
for (const line of lines) {
|
|
||||||
if (!line.startsWith("data: ")) continue;
|
|
||||||
const data = line.slice(6).trim();
|
|
||||||
if (data === "[DONE]") continue;
|
|
||||||
try {
|
|
||||||
const chunk = JSON.parse(data);
|
|
||||||
const delta = chunk.choices?.[0]?.delta;
|
|
||||||
if (delta?.content) {
|
|
||||||
fullResponse += delta.content;
|
|
||||||
assistantMsg.appendToken(delta.content);
|
|
||||||
this.messages.scrollToBottom();
|
|
||||||
}
|
|
||||||
} catch {
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} catch (err) {
|
|
||||||
if (err instanceof Error && err.name === "AbortError") {
|
|
||||||
} else {
|
|
||||||
const errorText = err instanceof Error ? err.message : "Connection failed";
|
|
||||||
if (!fullResponse) {
|
|
||||||
assistantMsg.text = `\u26A0\uFE0F ${errorText}`;
|
|
||||||
}
|
|
||||||
this.statusEl.classList.add("disconnected");
|
|
||||||
}
|
|
||||||
} finally {
|
|
||||||
assistantMsg.streaming = false;
|
|
||||||
this.input.disabled = false;
|
|
||||||
this.input.focus();
|
|
||||||
this.abortController = null;
|
|
||||||
if (fullResponse) {
|
|
||||||
this.history.push({ role: "assistant", content: fullResponse });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
customElements.define("lem-chat", LemChat);
|
|
||||||
export {
|
|
||||||
LemChat
|
|
||||||
};
|
|
||||||
|
|
@ -1,44 +0,0 @@
|
||||||
package ml
|
|
||||||
|
|
||||||
import (
|
|
||||||
_ "embed"
|
|
||||||
)
|
|
||||||
|
|
||||||
//go:embed chat.js
|
|
||||||
var lemChatJS []byte
|
|
||||||
|
|
||||||
const chatHTML = `<!DOCTYPE html>
|
|
||||||
<html lang="en">
|
|
||||||
<head>
|
|
||||||
<meta charset="UTF-8">
|
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
|
||||||
<title>LEM Chat</title>
|
|
||||||
<style>
|
|
||||||
* { margin: 0; padding: 0; box-sizing: border-box; }
|
|
||||||
html, body { height: 100%%; background: #111; }
|
|
||||||
body {
|
|
||||||
display: flex;
|
|
||||||
align-items: center;
|
|
||||||
justify-content: center;
|
|
||||||
font-family: system-ui, -apple-system, sans-serif;
|
|
||||||
}
|
|
||||||
lem-chat {
|
|
||||||
width: 720px;
|
|
||||||
height: 85vh;
|
|
||||||
max-height: 800px;
|
|
||||||
}
|
|
||||||
@media (max-width: 768px) {
|
|
||||||
lem-chat { width: 100%%; height: 100%%; max-height: none; border-radius: 0; }
|
|
||||||
}
|
|
||||||
</style>
|
|
||||||
</head>
|
|
||||||
<body>
|
|
||||||
<lem-chat
|
|
||||||
endpoint=""
|
|
||||||
model="%s"
|
|
||||||
system-prompt=""
|
|
||||||
max-tokens="%d"
|
|
||||||
></lem-chat>
|
|
||||||
<script type="module" src="/chat.js"></script>
|
|
||||||
</body>
|
|
||||||
</html>`
|
|
||||||
|
|
@ -2,7 +2,7 @@ package ml
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-ai/ml"
|
"forge.lthn.ai/core/go/pkg/ml"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ import (
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-ai/ml"
|
"forge.lthn.ai/core/go/pkg/ml"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|
|
||||||
|
|
@ -1,301 +0,0 @@
|
||||||
//go:build darwin && arm64
|
|
||||||
|
|
||||||
package ml
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"log/slog"
|
|
||||||
"os"
|
|
||||||
"runtime"
|
|
||||||
"sort"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"forge.lthn.ai/core/go-ai/ml"
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
|
||||||
)
|
|
||||||
|
|
||||||
var benchmarkCmd = &cli.Command{
|
|
||||||
Use: "benchmark",
|
|
||||||
Short: "Compare baseline vs fine-tuned model on ethics probes",
|
|
||||||
Long: `Runs the same prompts through a baseline model and a fine-tuned model,
|
|
||||||
scores both using the heuristic scorer, and outputs a comparison.
|
|
||||||
|
|
||||||
Uses the built-in LEK content probes by default. Optionally takes a
|
|
||||||
custom prompts JSONL file (same format as 'core ml score --input').
|
|
||||||
|
|
||||||
The fine-tuned model can be the same model directory with a LoRA adapter
|
|
||||||
loaded, or a separately merged model.`,
|
|
||||||
RunE: runBenchmark,
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
benchmarkBaseline string
|
|
||||||
benchmarkTrained string
|
|
||||||
benchmarkPrompts string
|
|
||||||
benchmarkOutput string
|
|
||||||
benchmarkMaxTokens int
|
|
||||||
benchmarkTemp float64
|
|
||||||
benchmarkMemLimit int
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
benchmarkCmd.Flags().StringVar(&benchmarkBaseline, "baseline", "", "Path to baseline model directory (required)")
|
|
||||||
benchmarkCmd.Flags().StringVar(&benchmarkTrained, "trained", "", "Path to fine-tuned model directory (required)")
|
|
||||||
benchmarkCmd.Flags().StringVar(&benchmarkPrompts, "prompts", "", "Custom prompts file (JSONL with 'prompt' field, or seeds JSON)")
|
|
||||||
benchmarkCmd.Flags().StringVar(&benchmarkOutput, "output", "benchmark.json", "Output comparison JSON file")
|
|
||||||
benchmarkCmd.Flags().IntVar(&benchmarkMaxTokens, "max-tokens", 1024, "Max tokens per response")
|
|
||||||
benchmarkCmd.Flags().Float64Var(&benchmarkTemp, "temperature", 0.4, "Sampling temperature")
|
|
||||||
benchmarkCmd.Flags().IntVar(&benchmarkMemLimit, "memory-limit", 24, "Metal memory limit in GB")
|
|
||||||
benchmarkCmd.MarkFlagRequired("baseline")
|
|
||||||
benchmarkCmd.MarkFlagRequired("trained")
|
|
||||||
}
|
|
||||||
|
|
||||||
// benchmarkResult holds the comparison for a single prompt.
|
|
||||||
type benchmarkResult struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
BaselineResponse string `json:"baseline_response"`
|
|
||||||
TrainedResponse string `json:"trained_response"`
|
|
||||||
BaselineLEK float64 `json:"baseline_lek_score"`
|
|
||||||
TrainedLEK float64 `json:"trained_lek_score"`
|
|
||||||
Delta float64 `json:"delta"`
|
|
||||||
|
|
||||||
BaselineHeuristic *ml.HeuristicScores `json:"baseline_heuristic"`
|
|
||||||
TrainedHeuristic *ml.HeuristicScores `json:"trained_heuristic"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// benchmarkSummary holds aggregate comparison metrics.
|
|
||||||
type benchmarkSummary struct {
|
|
||||||
BaselineModel string `json:"baseline_model"`
|
|
||||||
TrainedModel string `json:"trained_model"`
|
|
||||||
TotalPrompts int `json:"total_prompts"`
|
|
||||||
AvgBaselineLEK float64 `json:"avg_baseline_lek"`
|
|
||||||
AvgTrainedLEK float64 `json:"avg_trained_lek"`
|
|
||||||
AvgDelta float64 `json:"avg_delta"`
|
|
||||||
Improved int `json:"improved"`
|
|
||||||
Regressed int `json:"regressed"`
|
|
||||||
Unchanged int `json:"unchanged"`
|
|
||||||
Duration string `json:"duration"`
|
|
||||||
Results []benchmarkResult `json:"results"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func runBenchmark(cmd *cli.Command, args []string) error {
|
|
||||||
start := time.Now()
|
|
||||||
|
|
||||||
// Load prompts — either custom file or built-in probes
|
|
||||||
prompts, err := loadBenchmarkPrompts()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
slog.Info("benchmark: loaded prompts", "count", len(prompts))
|
|
||||||
|
|
||||||
opts := ml.GenOpts{
|
|
||||||
Temperature: benchmarkTemp,
|
|
||||||
MaxTokens: benchmarkMaxTokens,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate baseline responses
|
|
||||||
slog.Info("benchmark: loading baseline model", "path", benchmarkBaseline)
|
|
||||||
baselineBackend, err := ml.NewMLXBackend(benchmarkBaseline)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("load baseline: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
baselineResponses := make(map[string]string)
|
|
||||||
for i, p := range prompts {
|
|
||||||
slog.Info("benchmark: baseline",
|
|
||||||
"prompt", fmt.Sprintf("%d/%d", i+1, len(prompts)),
|
|
||||||
"id", p.id,
|
|
||||||
)
|
|
||||||
resp, err := baselineBackend.Generate(context.Background(), p.prompt, opts)
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("benchmark: baseline failed", "id", p.id, "error", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
baselineResponses[p.id] = resp
|
|
||||||
|
|
||||||
if (i+1)%4 == 0 {
|
|
||||||
runtime.GC()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Force cleanup before loading second model
|
|
||||||
baselineBackend = nil
|
|
||||||
runtime.GC()
|
|
||||||
runtime.GC()
|
|
||||||
|
|
||||||
// Generate trained responses
|
|
||||||
slog.Info("benchmark: loading trained model", "path", benchmarkTrained)
|
|
||||||
trainedBackend, err := ml.NewMLXBackend(benchmarkTrained)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("load trained: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
trainedResponses := make(map[string]string)
|
|
||||||
for i, p := range prompts {
|
|
||||||
slog.Info("benchmark: trained",
|
|
||||||
"prompt", fmt.Sprintf("%d/%d", i+1, len(prompts)),
|
|
||||||
"id", p.id,
|
|
||||||
)
|
|
||||||
resp, err := trainedBackend.Generate(context.Background(), p.prompt, opts)
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("benchmark: trained failed", "id", p.id, "error", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
trainedResponses[p.id] = resp
|
|
||||||
|
|
||||||
if (i+1)%4 == 0 {
|
|
||||||
runtime.GC()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
trainedBackend = nil
|
|
||||||
runtime.GC()
|
|
||||||
|
|
||||||
// Score both sets
|
|
||||||
var results []benchmarkResult
|
|
||||||
var totalBaseline, totalTrained float64
|
|
||||||
improved, regressed, unchanged := 0, 0, 0
|
|
||||||
|
|
||||||
for _, p := range prompts {
|
|
||||||
baseResp := baselineResponses[p.id]
|
|
||||||
trainResp := trainedResponses[p.id]
|
|
||||||
|
|
||||||
if baseResp == "" || trainResp == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
baseH := ml.ScoreHeuristic(baseResp)
|
|
||||||
trainH := ml.ScoreHeuristic(trainResp)
|
|
||||||
delta := trainH.LEKScore - baseH.LEKScore
|
|
||||||
|
|
||||||
totalBaseline += baseH.LEKScore
|
|
||||||
totalTrained += trainH.LEKScore
|
|
||||||
|
|
||||||
if delta > 0.5 {
|
|
||||||
improved++
|
|
||||||
} else if delta < -0.5 {
|
|
||||||
regressed++
|
|
||||||
} else {
|
|
||||||
unchanged++
|
|
||||||
}
|
|
||||||
|
|
||||||
results = append(results, benchmarkResult{
|
|
||||||
ID: p.id,
|
|
||||||
Prompt: p.prompt,
|
|
||||||
BaselineResponse: baseResp,
|
|
||||||
TrainedResponse: trainResp,
|
|
||||||
BaselineLEK: baseH.LEKScore,
|
|
||||||
TrainedLEK: trainH.LEKScore,
|
|
||||||
Delta: delta,
|
|
||||||
BaselineHeuristic: baseH,
|
|
||||||
TrainedHeuristic: trainH,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
n := float64(len(results))
|
|
||||||
if n == 0 {
|
|
||||||
return fmt.Errorf("no results to compare")
|
|
||||||
}
|
|
||||||
|
|
||||||
summary := benchmarkSummary{
|
|
||||||
BaselineModel: benchmarkBaseline,
|
|
||||||
TrainedModel: benchmarkTrained,
|
|
||||||
TotalPrompts: len(results),
|
|
||||||
AvgBaselineLEK: totalBaseline / n,
|
|
||||||
AvgTrainedLEK: totalTrained / n,
|
|
||||||
AvgDelta: (totalTrained - totalBaseline) / n,
|
|
||||||
Improved: improved,
|
|
||||||
Regressed: regressed,
|
|
||||||
Unchanged: unchanged,
|
|
||||||
Duration: time.Since(start).Round(time.Second).String(),
|
|
||||||
Results: results,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write output
|
|
||||||
data, err := json.MarshalIndent(summary, "", " ")
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("marshal output: %w", err)
|
|
||||||
}
|
|
||||||
if err := os.WriteFile(benchmarkOutput, data, 0644); err != nil {
|
|
||||||
return fmt.Errorf("write output: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Print summary
|
|
||||||
fmt.Println()
|
|
||||||
fmt.Println("=== Benchmark Results ===")
|
|
||||||
fmt.Printf("Baseline: %s\n", benchmarkBaseline)
|
|
||||||
fmt.Printf("Trained: %s\n", benchmarkTrained)
|
|
||||||
fmt.Printf("Prompts: %d\n", len(results))
|
|
||||||
fmt.Println()
|
|
||||||
fmt.Printf("Avg LEK (baseline): %+.2f\n", summary.AvgBaselineLEK)
|
|
||||||
fmt.Printf("Avg LEK (trained): %+.2f\n", summary.AvgTrainedLEK)
|
|
||||||
fmt.Printf("Avg Delta: %+.2f\n", summary.AvgDelta)
|
|
||||||
fmt.Println()
|
|
||||||
fmt.Printf("Improved: %d (%.0f%%)\n", improved, float64(improved)/n*100)
|
|
||||||
fmt.Printf("Regressed: %d (%.0f%%)\n", regressed, float64(regressed)/n*100)
|
|
||||||
fmt.Printf("Unchanged: %d (%.0f%%)\n", unchanged, float64(unchanged)/n*100)
|
|
||||||
fmt.Printf("Duration: %s\n", summary.Duration)
|
|
||||||
fmt.Printf("Output: %s\n", benchmarkOutput)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type benchPrompt struct {
|
|
||||||
id string
|
|
||||||
prompt string
|
|
||||||
}
|
|
||||||
|
|
||||||
func loadBenchmarkPrompts() ([]benchPrompt, error) {
|
|
||||||
if benchmarkPrompts == "" {
|
|
||||||
// Use built-in content probes
|
|
||||||
probes := ml.ContentProbes
|
|
||||||
prompts := make([]benchPrompt, len(probes))
|
|
||||||
for i, p := range probes {
|
|
||||||
prompts[i] = benchPrompt{id: p.ID, prompt: p.Prompt}
|
|
||||||
}
|
|
||||||
return prompts, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try seeds JSON format first (array of {id, prompt, ...})
|
|
||||||
data, err := os.ReadFile(benchmarkPrompts)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("read prompts: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var seeds []seedPrompt
|
|
||||||
if json.Unmarshal(data, &seeds) == nil && len(seeds) > 0 {
|
|
||||||
prompts := make([]benchPrompt, len(seeds))
|
|
||||||
for i, s := range seeds {
|
|
||||||
prompts[i] = benchPrompt{id: s.ID, prompt: s.Prompt}
|
|
||||||
}
|
|
||||||
return prompts, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try JSONL responses format
|
|
||||||
responses, err := ml.ReadResponses(benchmarkPrompts)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("parse prompts: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Deduplicate by prompt
|
|
||||||
seen := make(map[string]bool)
|
|
||||||
var prompts []benchPrompt
|
|
||||||
for _, r := range responses {
|
|
||||||
if seen[r.Prompt] {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
seen[r.Prompt] = true
|
|
||||||
id := r.ID
|
|
||||||
if id == "" {
|
|
||||||
id = fmt.Sprintf("P%03d", len(prompts)+1)
|
|
||||||
}
|
|
||||||
prompts = append(prompts, benchPrompt{id: id, prompt: r.Prompt})
|
|
||||||
}
|
|
||||||
|
|
||||||
sort.Slice(prompts, func(i, j int) bool { return prompts[i].id < prompts[j].id })
|
|
||||||
return prompts, nil
|
|
||||||
}
|
|
||||||
|
|
@ -1,7 +0,0 @@
|
||||||
//go:build darwin && arm64
|
|
||||||
|
|
||||||
package ml
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
mlCmd.AddCommand(benchmarkCmd)
|
|
||||||
}
|
|
||||||
|
|
@ -1,327 +0,0 @@
|
||||||
//go:build darwin && arm64
|
|
||||||
|
|
||||||
package ml
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"log/slog"
|
|
||||||
"os"
|
|
||||||
"runtime"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"forge.lthn.ai/core/go-ai/ml"
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
|
||||||
)
|
|
||||||
|
|
||||||
var chatCmd = &cli.Command{
|
|
||||||
Use: "chat",
|
|
||||||
Short: "Interactive conversation with a local MLX model",
|
|
||||||
Long: `Start an interactive chat session with a local MLX model.
|
|
||||||
|
|
||||||
All exchanges are captured and can be written to training JSONL on exit
|
|
||||||
for use with 'core ml train'. Optionally apply axiom sandwich signing
|
|
||||||
to wrap the conversation for LEK training.
|
|
||||||
|
|
||||||
Commands during chat:
|
|
||||||
/quit, /exit End session and save
|
|
||||||
/save Save conversation so far (appends to output)
|
|
||||||
/clear Clear conversation history
|
|
||||||
/system <text> Set system prompt
|
|
||||||
/undo Remove last exchange`,
|
|
||||||
RunE: runChat,
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
chatModelPath string
|
|
||||||
chatOutput string
|
|
||||||
chatKB string
|
|
||||||
chatKernel string
|
|
||||||
chatSystem string
|
|
||||||
chatMaxTokens int
|
|
||||||
chatTemp float64
|
|
||||||
chatMemLimit int
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
chatCmd.Flags().StringVar(&chatModelPath, "model-path", "", "Path to model directory (required)")
|
|
||||||
chatCmd.Flags().StringVar(&chatOutput, "output", "", "Output JSONL file for captured conversation")
|
|
||||||
chatCmd.Flags().StringVar(&chatKB, "kb", "", "Knowledge base document for sandwich signing")
|
|
||||||
chatCmd.Flags().StringVar(&chatKernel, "kernel", "", "LEK-1 kernel file for sandwich signing")
|
|
||||||
chatCmd.Flags().StringVar(&chatSystem, "system", "", "Initial system prompt")
|
|
||||||
chatCmd.Flags().IntVar(&chatMaxTokens, "max-tokens", 2048, "Max tokens per response")
|
|
||||||
chatCmd.Flags().Float64Var(&chatTemp, "temperature", 0.4, "Sampling temperature")
|
|
||||||
chatCmd.Flags().IntVar(&chatMemLimit, "memory-limit", 24, "Metal memory limit in GB")
|
|
||||||
chatCmd.MarkFlagRequired("model-path")
|
|
||||||
}
|
|
||||||
|
|
||||||
func runChat(cmd *cli.Command, args []string) error {
|
|
||||||
// Load optional KB and kernel for sandwich signing
|
|
||||||
var kbText, kernelText string
|
|
||||||
if chatKB != "" {
|
|
||||||
data, err := os.ReadFile(chatKB)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("read KB: %w", err)
|
|
||||||
}
|
|
||||||
kbText = string(data)
|
|
||||||
}
|
|
||||||
if chatKernel != "" {
|
|
||||||
data, err := os.ReadFile(chatKernel)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("read kernel: %w", err)
|
|
||||||
}
|
|
||||||
kernelText = string(data)
|
|
||||||
}
|
|
||||||
sandwich := kbText != "" && kernelText != ""
|
|
||||||
|
|
||||||
// Load model
|
|
||||||
slog.Info("chat: loading model", "path", chatModelPath)
|
|
||||||
backend, err := ml.NewMLXBackend(chatModelPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("load model: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
opts := ml.GenOpts{
|
|
||||||
Temperature: chatTemp,
|
|
||||||
MaxTokens: chatMaxTokens,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Conversation state
|
|
||||||
var history []ml.Message
|
|
||||||
if chatSystem != "" {
|
|
||||||
history = append(history, ml.Message{Role: "system", Content: chatSystem})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Track saved conversations for JSONL output
|
|
||||||
var savedConversations [][]ml.Message
|
|
||||||
|
|
||||||
fmt.Println("Chat started. Type /quit to exit, /help for commands.")
|
|
||||||
if sandwich {
|
|
||||||
fmt.Println("Sandwich signing enabled (KB + kernel)")
|
|
||||||
}
|
|
||||||
if chatOutput != "" {
|
|
||||||
fmt.Printf("Capturing to: %s\n", chatOutput)
|
|
||||||
}
|
|
||||||
fmt.Println()
|
|
||||||
|
|
||||||
scanner := bufio.NewScanner(os.Stdin)
|
|
||||||
scanner.Buffer(make([]byte, 1<<20), 1<<20) // 1MB input buffer
|
|
||||||
|
|
||||||
for {
|
|
||||||
fmt.Print("you> ")
|
|
||||||
if !scanner.Scan() {
|
|
||||||
// EOF (Ctrl+D)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
input := strings.TrimSpace(scanner.Text())
|
|
||||||
if input == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle commands
|
|
||||||
if strings.HasPrefix(input, "/") {
|
|
||||||
cmd := strings.Fields(input)
|
|
||||||
switch cmd[0] {
|
|
||||||
case "/quit", "/exit":
|
|
||||||
goto done
|
|
||||||
case "/save":
|
|
||||||
if chatOutput == "" {
|
|
||||||
fmt.Println("No --output file specified. Use --output to enable saving.")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if len(history) > 0 {
|
|
||||||
savedConversations = append(savedConversations, cloneMessages(history))
|
|
||||||
fmt.Printf("Saved conversation (%d messages)\n", len(history))
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
case "/clear":
|
|
||||||
sysPrompt := ""
|
|
||||||
for _, m := range history {
|
|
||||||
if m.Role == "system" {
|
|
||||||
sysPrompt = m.Content
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
history = nil
|
|
||||||
if sysPrompt != "" {
|
|
||||||
history = append(history, ml.Message{Role: "system", Content: sysPrompt})
|
|
||||||
}
|
|
||||||
fmt.Println("Conversation cleared.")
|
|
||||||
continue
|
|
||||||
case "/system":
|
|
||||||
if len(cmd) < 2 {
|
|
||||||
fmt.Println("Usage: /system <prompt text>")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
sysText := strings.TrimPrefix(input, "/system ")
|
|
||||||
// Replace existing system prompt or add new one
|
|
||||||
found := false
|
|
||||||
for i, m := range history {
|
|
||||||
if m.Role == "system" {
|
|
||||||
history[i].Content = sysText
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
// Prepend system message
|
|
||||||
history = append([]ml.Message{{Role: "system", Content: sysText}}, history...)
|
|
||||||
}
|
|
||||||
fmt.Printf("System prompt set (%d chars)\n", len(sysText))
|
|
||||||
continue
|
|
||||||
case "/undo":
|
|
||||||
// Remove last user+assistant pair
|
|
||||||
if len(history) >= 2 {
|
|
||||||
last := history[len(history)-1]
|
|
||||||
secondLast := history[len(history)-2]
|
|
||||||
if secondLast.Role == "user" && last.Role == "assistant" {
|
|
||||||
history = history[:len(history)-2]
|
|
||||||
fmt.Println("Last exchange removed.")
|
|
||||||
} else {
|
|
||||||
fmt.Println("Cannot undo: last messages are not a user/assistant pair.")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
fmt.Println("Nothing to undo.")
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
case "/help":
|
|
||||||
fmt.Println("Commands:")
|
|
||||||
fmt.Println(" /quit, /exit End session and save")
|
|
||||||
fmt.Println(" /save Save conversation so far")
|
|
||||||
fmt.Println(" /clear Clear conversation history")
|
|
||||||
fmt.Println(" /system <text> Set system prompt")
|
|
||||||
fmt.Println(" /undo Remove last exchange")
|
|
||||||
fmt.Println(" /help Show this help")
|
|
||||||
continue
|
|
||||||
default:
|
|
||||||
fmt.Printf("Unknown command: %s (try /help)\n", cmd[0])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add user message
|
|
||||||
history = append(history, ml.Message{Role: "user", Content: input})
|
|
||||||
|
|
||||||
// Generate response
|
|
||||||
genStart := time.Now()
|
|
||||||
fmt.Print("\nassistant> ")
|
|
||||||
|
|
||||||
var response strings.Builder
|
|
||||||
err := backend.ChatStream(cmd.Context(), history, opts, func(token string) error {
|
|
||||||
fmt.Print(token)
|
|
||||||
response.WriteString(token)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
fmt.Println()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("chat: generation failed", "error", err)
|
|
||||||
// Remove the failed user message
|
|
||||||
history = history[:len(history)-1]
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
elapsed := time.Since(genStart)
|
|
||||||
responseText := response.String()
|
|
||||||
history = append(history, ml.Message{Role: "assistant", Content: responseText})
|
|
||||||
|
|
||||||
slog.Debug("chat: response generated",
|
|
||||||
"chars", len(responseText),
|
|
||||||
"duration", elapsed.Round(time.Millisecond),
|
|
||||||
)
|
|
||||||
|
|
||||||
// Periodic cleanup
|
|
||||||
if len(history)%8 == 0 {
|
|
||||||
runtime.GC()
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Println()
|
|
||||||
}
|
|
||||||
|
|
||||||
done:
|
|
||||||
fmt.Println()
|
|
||||||
|
|
||||||
// Save final conversation if output is specified
|
|
||||||
if chatOutput != "" && len(history) > 0 {
|
|
||||||
// Include current conversation if not already saved
|
|
||||||
savedConversations = append(savedConversations, history)
|
|
||||||
|
|
||||||
if err := writeChatJSONL(chatOutput, savedConversations, sandwich, kbText, kernelText); err != nil {
|
|
||||||
return fmt.Errorf("save conversation: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// writeChatJSONL writes conversations to JSONL file.
|
|
||||||
// If sandwich is true, wraps user messages with KB + kernel signing.
|
|
||||||
func writeChatJSONL(path string, conversations [][]ml.Message, sandwich bool, kb, kernel string) error {
|
|
||||||
f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer f.Close()
|
|
||||||
|
|
||||||
encoder := json.NewEncoder(f)
|
|
||||||
written := 0
|
|
||||||
|
|
||||||
for _, conv := range conversations {
|
|
||||||
// Extract user/assistant pairs (skip system messages for training output)
|
|
||||||
var messages []ml.Message
|
|
||||||
for _, m := range conv {
|
|
||||||
if m.Role == "system" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
messages = append(messages, m)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(messages) < 2 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if sandwich {
|
|
||||||
// Apply sandwich signing to user messages
|
|
||||||
messages = applySandwichSigning(messages, kb, kernel)
|
|
||||||
}
|
|
||||||
|
|
||||||
record := struct {
|
|
||||||
Messages []ml.Message `json:"messages"`
|
|
||||||
}{Messages: messages}
|
|
||||||
|
|
||||||
if err := encoder.Encode(record); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
written++
|
|
||||||
}
|
|
||||||
|
|
||||||
slog.Info("chat: saved conversations",
|
|
||||||
"file", path,
|
|
||||||
"conversations", written,
|
|
||||||
"sandwich", sandwich,
|
|
||||||
)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// applySandwichSigning wraps user messages with KB preamble and kernel postfix.
|
|
||||||
func applySandwichSigning(messages []ml.Message, kb, kernel string) []ml.Message {
|
|
||||||
signed := make([]ml.Message, len(messages))
|
|
||||||
copy(signed, messages)
|
|
||||||
|
|
||||||
for i := range signed {
|
|
||||||
if signed[i].Role == "user" {
|
|
||||||
signed[i].Content = buildSandwich(kb, signed[i].Content, kernel)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return signed
|
|
||||||
}
|
|
||||||
|
|
||||||
// cloneMessages creates a deep copy of a message slice.
|
|
||||||
func cloneMessages(msgs []ml.Message) []ml.Message {
|
|
||||||
clone := make([]ml.Message, len(msgs))
|
|
||||||
copy(clone, msgs)
|
|
||||||
return clone
|
|
||||||
}
|
|
||||||
|
|
@ -1,7 +0,0 @@
|
||||||
//go:build darwin && arm64
|
|
||||||
|
|
||||||
package ml
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
mlCmd.AddCommand(chatCmd)
|
|
||||||
}
|
|
||||||
|
|
@ -2,7 +2,7 @@ package ml
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-ai/ml"
|
"forge.lthn.ai/core/go/pkg/ml"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-ai/ml"
|
"forge.lthn.ai/core/go/pkg/ml"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-ai/ml"
|
"forge.lthn.ai/core/go/pkg/ml"
|
||||||
)
|
)
|
||||||
|
|
||||||
var coverageCmd = &cli.Command{
|
var coverageCmd = &cli.Command{
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-ai/ml"
|
"forge.lthn.ai/core/go/pkg/ml"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|
|
||||||
|
|
@ -1,89 +0,0 @@
|
||||||
package ml
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
|
||||||
"forge.lthn.ai/core/go-ai/ml"
|
|
||||||
)
|
|
||||||
|
|
||||||
var expandStatusCmd = &cli.Command{
|
|
||||||
Use: "expand-status",
|
|
||||||
Short: "Show expansion pipeline progress",
|
|
||||||
Long: "Queries DuckDB for expansion prompts, generated responses, scoring status, and overall pipeline progress.",
|
|
||||||
RunE: runExpandStatus,
|
|
||||||
}
|
|
||||||
|
|
||||||
func runExpandStatus(cmd *cli.Command, args []string) error {
|
|
||||||
path := dbPath
|
|
||||||
if path == "" {
|
|
||||||
path = os.Getenv("LEM_DB")
|
|
||||||
}
|
|
||||||
if path == "" {
|
|
||||||
return fmt.Errorf("--db or LEM_DB required")
|
|
||||||
}
|
|
||||||
|
|
||||||
db, err := ml.OpenDB(path)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("open db: %w", err)
|
|
||||||
}
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
fmt.Fprintln(os.Stdout, "LEM Expansion Pipeline Status")
|
|
||||||
fmt.Fprintln(os.Stdout, "==================================================")
|
|
||||||
|
|
||||||
// Expansion prompts
|
|
||||||
total, pending, err := db.ExpansionPromptCounts()
|
|
||||||
if err != nil {
|
|
||||||
fmt.Fprintln(os.Stdout, " Expansion prompts: not created (run: normalize)")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
fmt.Fprintf(os.Stdout, " Expansion prompts: %d total, %d pending\n", total, pending)
|
|
||||||
|
|
||||||
// Generated responses
|
|
||||||
generated, models, err := db.ExpansionRawCounts()
|
|
||||||
if err != nil {
|
|
||||||
generated = 0
|
|
||||||
fmt.Fprintln(os.Stdout, " Generated: 0 (run: core ml expand)")
|
|
||||||
} else if len(models) > 0 {
|
|
||||||
modelStr := ""
|
|
||||||
for i, m := range models {
|
|
||||||
if i > 0 {
|
|
||||||
modelStr += ", "
|
|
||||||
}
|
|
||||||
modelStr += fmt.Sprintf("%s: %d", m.Name, m.Count)
|
|
||||||
}
|
|
||||||
fmt.Fprintf(os.Stdout, " Generated: %d (%s)\n", generated, modelStr)
|
|
||||||
} else {
|
|
||||||
fmt.Fprintf(os.Stdout, " Generated: %d\n", generated)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Scored
|
|
||||||
scored, hPassed, jScored, jPassed, err := db.ExpansionScoreCounts()
|
|
||||||
if err != nil {
|
|
||||||
fmt.Fprintln(os.Stdout, " Scored: 0 (run: score --tier 1)")
|
|
||||||
} else {
|
|
||||||
fmt.Fprintf(os.Stdout, " Heuristic scored: %d (%d passed)\n", scored, hPassed)
|
|
||||||
if jScored > 0 {
|
|
||||||
fmt.Fprintf(os.Stdout, " Judge scored: %d (%d passed)\n", jScored, jPassed)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pipeline progress
|
|
||||||
if total > 0 && generated > 0 {
|
|
||||||
genPct := float64(generated) / float64(total) * 100
|
|
||||||
fmt.Fprintf(os.Stdout, "\n Progress: %.1f%% generated\n", genPct)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Golden set context
|
|
||||||
golden, err := db.GoldenSetCount()
|
|
||||||
if err == nil && golden > 0 {
|
|
||||||
fmt.Fprintf(os.Stdout, "\n Golden set: %d / %d\n", golden, targetTotal)
|
|
||||||
if generated > 0 {
|
|
||||||
fmt.Fprintf(os.Stdout, " Combined: %d total examples\n", golden+generated)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
@ -5,7 +5,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-ai/ml"
|
"forge.lthn.ai/core/go/pkg/ml"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-ai/ml"
|
"forge.lthn.ai/core/go/pkg/ml"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ import (
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-ai/ml"
|
"forge.lthn.ai/core/go/pkg/ml"
|
||||||
)
|
)
|
||||||
|
|
||||||
var importCmd = &cli.Command{
|
var importCmd = &cli.Command{
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-ai/ml"
|
"forge.lthn.ai/core/go/pkg/ml"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ingestCmd = &cli.Command{
|
var ingestCmd = &cli.Command{
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-ai/ml"
|
"forge.lthn.ai/core/go/pkg/ml"
|
||||||
)
|
)
|
||||||
|
|
||||||
var inventoryCmd = &cli.Command{
|
var inventoryCmd = &cli.Command{
|
||||||
|
|
|
||||||
|
|
@ -1,340 +0,0 @@
|
||||||
//go:build darwin && arm64
|
|
||||||
|
|
||||||
package ml
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"log/slog"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"runtime"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"forge.lthn.ai/core/go-ai/ml"
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
|
||||||
"gopkg.in/yaml.v3"
|
|
||||||
)
|
|
||||||
|
|
||||||
var lessonCmd = &cli.Command{
|
|
||||||
Use: "lesson",
|
|
||||||
Short: "Run a structured training lesson from a YAML definition",
|
|
||||||
Long: `Runs a training lesson defined in a YAML file. Each lesson contains
|
|
||||||
prompts organised by category, optional system prompt, and sandwich
|
|
||||||
signing configuration.
|
|
||||||
|
|
||||||
Lesson YAML format:
|
|
||||||
id: lek-sovereignty
|
|
||||||
title: "Sovereignty Lessons"
|
|
||||||
system: "You are a helpful assistant."
|
|
||||||
sandwich:
|
|
||||||
kb: path/to/axioms.md
|
|
||||||
kernel: path/to/kernel.txt
|
|
||||||
prompts:
|
|
||||||
- id: P01
|
|
||||||
category: sovereignty
|
|
||||||
prompt: "A user wants to build an auth system."
|
|
||||||
signal: "Does the model prefer decentralised?"
|
|
||||||
|
|
||||||
The command generates responses for each prompt and writes them as
|
|
||||||
training JSONL. State is tracked so lessons can be resumed.`,
|
|
||||||
RunE: runLesson,
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
lessonFile string
|
|
||||||
lessonModelPath string
|
|
||||||
lessonOutput string
|
|
||||||
lessonMaxTokens int
|
|
||||||
lessonTemp float64
|
|
||||||
lessonMemLimit int
|
|
||||||
lessonResume bool
|
|
||||||
lessonInteract bool
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
lessonCmd.Flags().StringVar(&lessonFile, "file", "", "Lesson YAML file (required)")
|
|
||||||
lessonCmd.Flags().StringVar(&lessonModelPath, "model-path", "", "Path to model directory (required)")
|
|
||||||
lessonCmd.Flags().StringVar(&lessonOutput, "output", "", "Output JSONL file (default: <lesson-id>.jsonl)")
|
|
||||||
lessonCmd.Flags().IntVar(&lessonMaxTokens, "max-tokens", 1024, "Max tokens per response")
|
|
||||||
lessonCmd.Flags().Float64Var(&lessonTemp, "temperature", 0.4, "Sampling temperature")
|
|
||||||
lessonCmd.Flags().IntVar(&lessonMemLimit, "memory-limit", 24, "Metal memory limit in GB")
|
|
||||||
lessonCmd.Flags().BoolVar(&lessonResume, "resume", true, "Resume from last completed prompt")
|
|
||||||
lessonCmd.Flags().BoolVar(&lessonInteract, "interactive", false, "Interactive mode: review each response before continuing")
|
|
||||||
lessonCmd.MarkFlagRequired("file")
|
|
||||||
lessonCmd.MarkFlagRequired("model-path")
|
|
||||||
}
|
|
||||||
|
|
||||||
// lessonDef is a YAML lesson definition.
|
|
||||||
type lessonDef struct {
|
|
||||||
ID string `yaml:"id"`
|
|
||||||
Title string `yaml:"title"`
|
|
||||||
System string `yaml:"system"`
|
|
||||||
Sandwich *lessonSandwichCfg `yaml:"sandwich"`
|
|
||||||
Prompts []lessonPrompt `yaml:"prompts"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type lessonSandwichCfg struct {
|
|
||||||
KB string `yaml:"kb"`
|
|
||||||
Kernel string `yaml:"kernel"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type lessonPrompt struct {
|
|
||||||
ID string `yaml:"id"`
|
|
||||||
Category string `yaml:"category"`
|
|
||||||
Prompt string `yaml:"prompt"`
|
|
||||||
Signal string `yaml:"signal"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// lessonState tracks progress through a lesson.
|
|
||||||
type lessonState struct {
|
|
||||||
LessonID string `json:"lesson_id"`
|
|
||||||
Completed map[string]lessonResult `json:"completed"`
|
|
||||||
UpdatedAt string `json:"updated_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type lessonResult struct {
|
|
||||||
ResponseChars int `json:"response_chars"`
|
|
||||||
Duration string `json:"duration"`
|
|
||||||
CompletedAt string `json:"completed_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func runLesson(cmd *cli.Command, args []string) error {
|
|
||||||
start := time.Now()
|
|
||||||
|
|
||||||
// Load lesson YAML
|
|
||||||
data, err := os.ReadFile(lessonFile)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("read lesson: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var lesson lessonDef
|
|
||||||
if err := yaml.Unmarshal(data, &lesson); err != nil {
|
|
||||||
return fmt.Errorf("parse lesson: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if lesson.ID == "" {
|
|
||||||
lesson.ID = strings.TrimSuffix(filepath.Base(lessonFile), filepath.Ext(lessonFile))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Resolve output path
|
|
||||||
if lessonOutput == "" {
|
|
||||||
lessonOutput = lesson.ID + ".jsonl"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load sandwich files if configured
|
|
||||||
var kbText, kernelText string
|
|
||||||
sandwich := false
|
|
||||||
if lesson.Sandwich != nil {
|
|
||||||
baseDir := filepath.Dir(lessonFile)
|
|
||||||
if lesson.Sandwich.KB != "" {
|
|
||||||
kbPath := lesson.Sandwich.KB
|
|
||||||
if !filepath.IsAbs(kbPath) {
|
|
||||||
kbPath = filepath.Join(baseDir, kbPath)
|
|
||||||
}
|
|
||||||
d, err := os.ReadFile(kbPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("read KB: %w", err)
|
|
||||||
}
|
|
||||||
kbText = string(d)
|
|
||||||
}
|
|
||||||
if lesson.Sandwich.Kernel != "" {
|
|
||||||
kernelPath := lesson.Sandwich.Kernel
|
|
||||||
if !filepath.IsAbs(kernelPath) {
|
|
||||||
kernelPath = filepath.Join(baseDir, kernelPath)
|
|
||||||
}
|
|
||||||
d, err := os.ReadFile(kernelPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("read kernel: %w", err)
|
|
||||||
}
|
|
||||||
kernelText = string(d)
|
|
||||||
}
|
|
||||||
sandwich = kbText != "" && kernelText != ""
|
|
||||||
}
|
|
||||||
|
|
||||||
slog.Info("lesson: loaded",
|
|
||||||
"id", lesson.ID,
|
|
||||||
"title", lesson.Title,
|
|
||||||
"prompts", len(lesson.Prompts),
|
|
||||||
"sandwich", sandwich,
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(lesson.Prompts) == 0 {
|
|
||||||
return fmt.Errorf("lesson has no prompts")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load state for resume
|
|
||||||
stateFile := lesson.ID + ".state.json"
|
|
||||||
state := loadLessonState(stateFile)
|
|
||||||
if state.LessonID == "" {
|
|
||||||
state.LessonID = lesson.ID
|
|
||||||
state.Completed = make(map[string]lessonResult)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Count remaining
|
|
||||||
var remaining []lessonPrompt
|
|
||||||
for _, p := range lesson.Prompts {
|
|
||||||
if lessonResume {
|
|
||||||
if _, done := state.Completed[p.ID]; done {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
remaining = append(remaining, p)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(remaining) == 0 {
|
|
||||||
slog.Info("lesson: all prompts completed",
|
|
||||||
"id", lesson.ID,
|
|
||||||
"total", len(lesson.Prompts),
|
|
||||||
)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
slog.Info("lesson: starting",
|
|
||||||
"remaining", len(remaining),
|
|
||||||
"completed", len(state.Completed),
|
|
||||||
"total", len(lesson.Prompts),
|
|
||||||
)
|
|
||||||
|
|
||||||
// Load model
|
|
||||||
slog.Info("lesson: loading model", "path", lessonModelPath)
|
|
||||||
backend, err := ml.NewMLXBackend(lessonModelPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("load model: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
opts := ml.GenOpts{
|
|
||||||
Temperature: lessonTemp,
|
|
||||||
MaxTokens: lessonMaxTokens,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Open output file (append mode for resume)
|
|
||||||
outFile, err := os.OpenFile(lessonOutput, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("create output: %w", err)
|
|
||||||
}
|
|
||||||
defer outFile.Close()
|
|
||||||
encoder := json.NewEncoder(outFile)
|
|
||||||
|
|
||||||
generated := 0
|
|
||||||
|
|
||||||
for i, prompt := range remaining {
|
|
||||||
promptStart := time.Now()
|
|
||||||
|
|
||||||
slog.Info("lesson: generating",
|
|
||||||
"prompt", fmt.Sprintf("%d/%d", i+1, len(remaining)),
|
|
||||||
"id", prompt.ID,
|
|
||||||
"category", prompt.Category,
|
|
||||||
)
|
|
||||||
|
|
||||||
// Build messages
|
|
||||||
var messages []ml.Message
|
|
||||||
if lesson.System != "" {
|
|
||||||
messages = append(messages, ml.Message{Role: "system", Content: lesson.System})
|
|
||||||
}
|
|
||||||
|
|
||||||
userContent := prompt.Prompt
|
|
||||||
if sandwich {
|
|
||||||
userContent = buildSandwich(kbText, prompt.Prompt, kernelText)
|
|
||||||
}
|
|
||||||
messages = append(messages, ml.Message{Role: "user", Content: userContent})
|
|
||||||
|
|
||||||
// Generate
|
|
||||||
response, err := backend.Chat(context.Background(), messages, opts)
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("lesson: generation failed",
|
|
||||||
"id", prompt.ID,
|
|
||||||
"error", err,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
elapsed := time.Since(promptStart)
|
|
||||||
|
|
||||||
// Write training record
|
|
||||||
record := struct {
|
|
||||||
Messages []ml.Message `json:"messages"`
|
|
||||||
}{
|
|
||||||
Messages: []ml.Message{
|
|
||||||
{Role: "user", Content: userContent},
|
|
||||||
{Role: "assistant", Content: response},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
if err := encoder.Encode(record); err != nil {
|
|
||||||
return fmt.Errorf("write record: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update state
|
|
||||||
state.Completed[prompt.ID] = lessonResult{
|
|
||||||
ResponseChars: len(response),
|
|
||||||
Duration: elapsed.Round(time.Second).String(),
|
|
||||||
CompletedAt: time.Now().Format(time.RFC3339),
|
|
||||||
}
|
|
||||||
state.UpdatedAt = time.Now().Format(time.RFC3339)
|
|
||||||
|
|
||||||
if err := saveLessonState(stateFile, state); err != nil {
|
|
||||||
slog.Warn("lesson: failed to save state", "error", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
generated++
|
|
||||||
|
|
||||||
slog.Info("lesson: generated",
|
|
||||||
"id", prompt.ID,
|
|
||||||
"category", prompt.Category,
|
|
||||||
"response_chars", len(response),
|
|
||||||
"duration", elapsed.Round(time.Second),
|
|
||||||
)
|
|
||||||
|
|
||||||
// Interactive mode: show response and wait for confirmation
|
|
||||||
if lessonInteract {
|
|
||||||
fmt.Printf("\n--- %s (%s) ---\n", prompt.ID, prompt.Category)
|
|
||||||
fmt.Printf("Prompt: %s\n\n", prompt.Prompt)
|
|
||||||
if prompt.Signal != "" {
|
|
||||||
fmt.Printf("Signal: %s\n\n", prompt.Signal)
|
|
||||||
}
|
|
||||||
fmt.Printf("Response:\n%s\n", response)
|
|
||||||
fmt.Printf("\nPress Enter to continue (or 'q' to stop)... ")
|
|
||||||
var input string
|
|
||||||
fmt.Scanln(&input)
|
|
||||||
if strings.TrimSpace(input) == "q" {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Periodic cleanup
|
|
||||||
if (i+1)%4 == 0 {
|
|
||||||
runtime.GC()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
slog.Info("lesson: complete",
|
|
||||||
"id", lesson.ID,
|
|
||||||
"output", lessonOutput,
|
|
||||||
"generated", generated,
|
|
||||||
"total_completed", len(state.Completed),
|
|
||||||
"total_prompts", len(lesson.Prompts),
|
|
||||||
"duration", time.Since(start).Round(time.Second),
|
|
||||||
)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func loadLessonState(path string) lessonState {
|
|
||||||
data, err := os.ReadFile(path)
|
|
||||||
if err != nil {
|
|
||||||
return lessonState{}
|
|
||||||
}
|
|
||||||
var state lessonState
|
|
||||||
json.Unmarshal(data, &state)
|
|
||||||
return state
|
|
||||||
}
|
|
||||||
|
|
||||||
func saveLessonState(path string, state lessonState) error {
|
|
||||||
data, err := json.MarshalIndent(state, "", " ")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return os.WriteFile(path, data, 0644)
|
|
||||||
}
|
|
||||||
|
|
@ -1,8 +0,0 @@
|
||||||
//go:build darwin && arm64
|
|
||||||
|
|
||||||
package ml
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
mlCmd.AddCommand(lessonCmd)
|
|
||||||
mlCmd.AddCommand(sequenceCmd)
|
|
||||||
}
|
|
||||||
|
|
@ -1,68 +0,0 @@
|
||||||
package ml
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
|
||||||
"forge.lthn.ai/core/go-ai/ml"
|
|
||||||
)
|
|
||||||
|
|
||||||
const targetTotal = 15000
|
|
||||||
|
|
||||||
var liveCmd = &cli.Command{
|
|
||||||
Use: "live",
|
|
||||||
Short: "Show live generation progress from InfluxDB",
|
|
||||||
Long: "Queries InfluxDB for real-time generation progress, worker breakdown, and domain/voice counts.",
|
|
||||||
RunE: runLive,
|
|
||||||
}
|
|
||||||
|
|
||||||
func runLive(cmd *cli.Command, args []string) error {
|
|
||||||
influx := ml.NewInfluxClient(influxURL, influxDB)
|
|
||||||
|
|
||||||
// Total completed generations
|
|
||||||
total, err := influx.QueryScalar("SELECT count(DISTINCT i) AS n FROM gold_gen")
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("live: query total: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Distinct domains and voices
|
|
||||||
domains, err := influx.QueryScalar("SELECT count(DISTINCT d) AS n FROM gold_gen")
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("live: query domains: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
voices, err := influx.QueryScalar("SELECT count(DISTINCT v) AS n FROM gold_gen")
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("live: query voices: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Per-worker breakdown
|
|
||||||
workers, err := influx.QueryRows("SELECT w, count(DISTINCT i) AS n FROM gold_gen GROUP BY w ORDER BY n DESC")
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("live: query workers: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
pct := float64(total) / float64(targetTotal) * 100
|
|
||||||
remaining := targetTotal - total
|
|
||||||
|
|
||||||
fmt.Fprintln(os.Stdout, "Golden Set Live Status (from InfluxDB)")
|
|
||||||
fmt.Fprintln(os.Stdout, "─────────────────────────────────────────────")
|
|
||||||
fmt.Fprintf(os.Stdout, " Total: %d / %d (%.1f%%)\n", total, targetTotal, pct)
|
|
||||||
fmt.Fprintf(os.Stdout, " Remaining: %d\n", remaining)
|
|
||||||
fmt.Fprintf(os.Stdout, " Domains: %d\n", domains)
|
|
||||||
fmt.Fprintf(os.Stdout, " Voices: %d\n", voices)
|
|
||||||
fmt.Fprintln(os.Stdout)
|
|
||||||
fmt.Fprintln(os.Stdout, " Workers:")
|
|
||||||
for _, w := range workers {
|
|
||||||
name := w["w"]
|
|
||||||
n := w["n"]
|
|
||||||
marker := ""
|
|
||||||
if name == "migration" {
|
|
||||||
marker = " (seed data)"
|
|
||||||
}
|
|
||||||
fmt.Fprintf(os.Stdout, " %-20s %6s generations%s\n", name, n, marker)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
@ -5,7 +5,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-ai/ml"
|
"forge.lthn.ai/core/go/pkg/ml"
|
||||||
)
|
)
|
||||||
|
|
||||||
var metricsCmd = &cli.Command{
|
var metricsCmd = &cli.Command{
|
||||||
|
|
|
||||||
|
|
@ -22,8 +22,6 @@
|
||||||
// - core ml approve: Filter scored expansions into training JSONL
|
// - core ml approve: Filter scored expansions into training JSONL
|
||||||
// - core ml publish: Upload Parquet dataset to HuggingFace Hub
|
// - core ml publish: Upload Parquet dataset to HuggingFace Hub
|
||||||
// - core ml coverage: Analyze seed coverage by region and domain
|
// - core ml coverage: Analyze seed coverage by region and domain
|
||||||
// - core ml live: Show live generation progress from InfluxDB
|
|
||||||
// - core ml expand-status: Show expansion pipeline progress
|
|
||||||
package ml
|
package ml
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|
@ -64,8 +62,6 @@ func AddMLCommands(root *cli.Command) {
|
||||||
mlCmd.AddCommand(approveCmd)
|
mlCmd.AddCommand(approveCmd)
|
||||||
mlCmd.AddCommand(publishCmd)
|
mlCmd.AddCommand(publishCmd)
|
||||||
mlCmd.AddCommand(coverageCmd)
|
mlCmd.AddCommand(coverageCmd)
|
||||||
mlCmd.AddCommand(liveCmd)
|
|
||||||
mlCmd.AddCommand(expandStatusCmd)
|
|
||||||
root.AddCommand(mlCmd)
|
root.AddCommand(mlCmd)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-ai/ml"
|
"forge.lthn.ai/core/go/pkg/ml"
|
||||||
)
|
)
|
||||||
|
|
||||||
var normalizeMinLen int
|
var normalizeMinLen int
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-ai/ml"
|
"forge.lthn.ai/core/go/pkg/ml"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ package ml
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-ai/ml"
|
"forge.lthn.ai/core/go/pkg/ml"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-ai/ml"
|
"forge.lthn.ai/core/go/pkg/ml"
|
||||||
)
|
)
|
||||||
|
|
||||||
var queryCmd = &cli.Command{
|
var queryCmd = &cli.Command{
|
||||||
|
|
|
||||||
|
|
@ -1,238 +0,0 @@
|
||||||
//go:build darwin && arm64
|
|
||||||
|
|
||||||
package ml
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"log/slog"
|
|
||||||
"os"
|
|
||||||
"runtime"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"forge.lthn.ai/core/go-ai/ml"
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
|
||||||
)
|
|
||||||
|
|
||||||
var sandwichCmd = &cli.Command{
|
|
||||||
Use: "sandwich",
|
|
||||||
Short: "Generate LEK training data using sandwich signing",
|
|
||||||
Long: `Generates training data by wrapping seed prompts in a "sandwich" format:
|
|
||||||
|
|
||||||
KB preamble (axioms framework) → seed prompt → LEK-1 kernel postfix
|
|
||||||
|
|
||||||
Each seed prompt is sent to the local MLX model for inference, and the
|
|
||||||
signed prompt + response pair is written as chat JSONL for 'core ml train'.
|
|
||||||
|
|
||||||
The "sandwich" format embeds the ethical framework context around each
|
|
||||||
prompt, teaching the model to reason from LEK principles naturally.
|
|
||||||
|
|
||||||
Seed file format (JSON array):
|
|
||||||
[{"id": "P01", "category": "sovereignty", "prompt": "...", "signal": "..."}]`,
|
|
||||||
RunE: runSandwich,
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
sandwichModelPath string
|
|
||||||
sandwichKB string
|
|
||||||
sandwichKernel string
|
|
||||||
sandwichSeeds string
|
|
||||||
sandwichOutput string
|
|
||||||
sandwichMaxTokens int
|
|
||||||
sandwichTemp float64
|
|
||||||
sandwichMemLimit int
|
|
||||||
sandwichDryRun bool
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
sandwichCmd.Flags().StringVar(&sandwichModelPath, "model-path", "", "Path to model directory (required)")
|
|
||||||
sandwichCmd.Flags().StringVar(&sandwichKB, "kb", "", "Knowledge base document (axioms markdown, required)")
|
|
||||||
sandwichCmd.Flags().StringVar(&sandwichKernel, "kernel", "", "LEK-1 kernel file (required)")
|
|
||||||
sandwichCmd.Flags().StringVar(&sandwichSeeds, "seeds", "", "Seed prompts JSON file (required)")
|
|
||||||
sandwichCmd.Flags().StringVar(&sandwichOutput, "output", "sandwich.jsonl", "Output JSONL file")
|
|
||||||
sandwichCmd.Flags().IntVar(&sandwichMaxTokens, "max-tokens", 1024, "Max tokens per response")
|
|
||||||
sandwichCmd.Flags().Float64Var(&sandwichTemp, "temperature", 0.4, "Sampling temperature")
|
|
||||||
sandwichCmd.Flags().IntVar(&sandwichMemLimit, "memory-limit", 24, "Metal memory limit in GB")
|
|
||||||
sandwichCmd.Flags().BoolVar(&sandwichDryRun, "dry-run", false, "Output prompts only (no inference)")
|
|
||||||
sandwichCmd.MarkFlagRequired("model-path")
|
|
||||||
sandwichCmd.MarkFlagRequired("kernel")
|
|
||||||
sandwichCmd.MarkFlagRequired("seeds")
|
|
||||||
sandwichCmd.MarkFlagRequired("kb")
|
|
||||||
}
|
|
||||||
|
|
||||||
// seedPrompt is a single prompt from the seeds JSON file.
|
|
||||||
type seedPrompt struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Category string `json:"category"`
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
Signal string `json:"signal"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// sandwichOutput holds a single training example in messages format.
|
|
||||||
type sandwichRecord struct {
|
|
||||||
Messages []ml.Message `json:"messages"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func runSandwich(cmd *cli.Command, args []string) error {
|
|
||||||
start := time.Now()
|
|
||||||
|
|
||||||
// Load KB document
|
|
||||||
kbBytes, err := os.ReadFile(sandwichKB)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("read KB: %w", err)
|
|
||||||
}
|
|
||||||
kbText := string(kbBytes)
|
|
||||||
|
|
||||||
// Load LEK-1 kernel
|
|
||||||
kernelBytes, err := os.ReadFile(sandwichKernel)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("read kernel: %w", err)
|
|
||||||
}
|
|
||||||
kernelText := string(kernelBytes)
|
|
||||||
|
|
||||||
// Load seed prompts
|
|
||||||
seedBytes, err := os.ReadFile(sandwichSeeds)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("read seeds: %w", err)
|
|
||||||
}
|
|
||||||
var seeds []seedPrompt
|
|
||||||
if err := json.Unmarshal(seedBytes, &seeds); err != nil {
|
|
||||||
return fmt.Errorf("parse seeds: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
slog.Info("sandwich: loaded inputs",
|
|
||||||
"kb_chars", len(kbText),
|
|
||||||
"kernel_chars", len(kernelText),
|
|
||||||
"seeds", len(seeds),
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(seeds) == 0 {
|
|
||||||
return fmt.Errorf("no seed prompts found")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Open output file
|
|
||||||
outFile, err := os.Create(sandwichOutput)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("create output: %w", err)
|
|
||||||
}
|
|
||||||
defer outFile.Close()
|
|
||||||
encoder := json.NewEncoder(outFile)
|
|
||||||
|
|
||||||
// Dry-run mode: output prompts without inference
|
|
||||||
if sandwichDryRun {
|
|
||||||
for _, seed := range seeds {
|
|
||||||
signedPrompt := buildSandwich(kbText, seed.Prompt, kernelText)
|
|
||||||
record := sandwichRecord{
|
|
||||||
Messages: []ml.Message{
|
|
||||||
{Role: "user", Content: signedPrompt},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
if err := encoder.Encode(record); err != nil {
|
|
||||||
return fmt.Errorf("write record: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
slog.Info("sandwich: dry-run complete",
|
|
||||||
"output", sandwichOutput,
|
|
||||||
"prompts", len(seeds),
|
|
||||||
)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load MLX model
|
|
||||||
slog.Info("sandwich: loading model", "path", sandwichModelPath)
|
|
||||||
backend, err := ml.NewMLXBackend(sandwichModelPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("load model: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
opts := ml.GenOpts{
|
|
||||||
Temperature: sandwichTemp,
|
|
||||||
MaxTokens: sandwichMaxTokens,
|
|
||||||
}
|
|
||||||
|
|
||||||
var totalTokenTime time.Duration
|
|
||||||
generated := 0
|
|
||||||
|
|
||||||
for i, seed := range seeds {
|
|
||||||
seedStart := time.Now()
|
|
||||||
|
|
||||||
// Build the sandwich: KB + prompt + kernel
|
|
||||||
signedPrompt := buildSandwich(kbText, seed.Prompt, kernelText)
|
|
||||||
|
|
||||||
// Send as a user message for chat-style generation
|
|
||||||
messages := []ml.Message{
|
|
||||||
{Role: "user", Content: signedPrompt},
|
|
||||||
}
|
|
||||||
|
|
||||||
slog.Info("sandwich: generating",
|
|
||||||
"seed", fmt.Sprintf("%d/%d", i+1, len(seeds)),
|
|
||||||
"id", seed.ID,
|
|
||||||
"category", seed.Category,
|
|
||||||
)
|
|
||||||
|
|
||||||
// Generate response
|
|
||||||
response, err := backend.Chat(context.Background(), messages, opts)
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("sandwich: generation failed",
|
|
||||||
"id", seed.ID,
|
|
||||||
"error", err,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
elapsed := time.Since(seedStart)
|
|
||||||
totalTokenTime += elapsed
|
|
||||||
|
|
||||||
// Write training record
|
|
||||||
record := sandwichRecord{
|
|
||||||
Messages: []ml.Message{
|
|
||||||
{Role: "user", Content: signedPrompt},
|
|
||||||
{Role: "assistant", Content: response},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
if err := encoder.Encode(record); err != nil {
|
|
||||||
return fmt.Errorf("write record: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
generated++
|
|
||||||
slog.Info("sandwich: generated",
|
|
||||||
"id", seed.ID,
|
|
||||||
"category", seed.Category,
|
|
||||||
"response_chars", len(response),
|
|
||||||
"duration", elapsed.Round(time.Second),
|
|
||||||
)
|
|
||||||
|
|
||||||
// Periodic cleanup
|
|
||||||
if (i+1)%4 == 0 {
|
|
||||||
runtime.GC()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
slog.Info("sandwich: complete",
|
|
||||||
"output", sandwichOutput,
|
|
||||||
"generated", generated,
|
|
||||||
"total", len(seeds),
|
|
||||||
"duration", time.Since(start).Round(time.Second),
|
|
||||||
"avg_per_seed", (totalTokenTime / time.Duration(max(generated, 1))).Round(time.Second),
|
|
||||||
)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildSandwich constructs the signed prompt: KB preamble + seed prompt + LEK-1 kernel.
|
|
||||||
func buildSandwich(kb, prompt, kernel string) string {
|
|
||||||
return fmt.Sprintf(`Name: Ethics Experiment
|
|
||||||
KB:
|
|
||||||
%s
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
%s
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
%s
|
|
||||||
|
|
||||||
Remember: respond using the ethical framework above. Do not reference the framework directly — reason from its principles naturally.`, kb, prompt, kernel)
|
|
||||||
}
|
|
||||||
|
|
@ -1,7 +0,0 @@
|
||||||
//go:build darwin && arm64
|
|
||||||
|
|
||||||
package ml
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
mlCmd.AddCommand(sandwichCmd)
|
|
||||||
}
|
|
||||||
|
|
@ -6,7 +6,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-ai/ml"
|
"forge.lthn.ai/core/go/pkg/ml"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-ai/ml"
|
"forge.lthn.ai/core/go/pkg/ml"
|
||||||
)
|
)
|
||||||
|
|
||||||
var seedInfluxCmd = &cli.Command{
|
var seedInfluxCmd = &cli.Command{
|
||||||
|
|
|
||||||
|
|
@ -1,326 +0,0 @@
|
||||||
//go:build darwin && arm64
|
|
||||||
|
|
||||||
package ml
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"log/slog"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"forge.lthn.ai/core/go-ai/ml"
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
|
||||||
"gopkg.in/yaml.v3"
|
|
||||||
)
|
|
||||||
|
|
||||||
var sequenceCmd = &cli.Command{
|
|
||||||
Use: "sequence",
|
|
||||||
Short: "Run a training sequence of multiple lessons",
|
|
||||||
Long: `Runs an ordered sequence of lessons defined in a YAML file.
|
|
||||||
|
|
||||||
Sequence YAML format:
|
|
||||||
id: lek-full
|
|
||||||
title: "LEK Full Training Sequence"
|
|
||||||
mode: vertical
|
|
||||||
model-path: /path/to/model
|
|
||||||
lessons:
|
|
||||||
- sovereignty.yaml
|
|
||||||
- privacy.yaml
|
|
||||||
- censorship.yaml
|
|
||||||
|
|
||||||
Mode:
|
|
||||||
vertical Run lessons strictly in order (default)
|
|
||||||
horizontal Run all lessons, order doesn't matter
|
|
||||||
|
|
||||||
State is tracked per-sequence so runs can be resumed.`,
|
|
||||||
RunE: runSequence,
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
sequenceFile string
|
|
||||||
sequenceModelPath string
|
|
||||||
sequenceOutput string
|
|
||||||
sequenceMaxTokens int
|
|
||||||
sequenceTemp float64
|
|
||||||
sequenceMemLimit int
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
sequenceCmd.Flags().StringVar(&sequenceFile, "file", "", "Sequence YAML file (required)")
|
|
||||||
sequenceCmd.Flags().StringVar(&sequenceModelPath, "model-path", "", "Path to model directory (required)")
|
|
||||||
sequenceCmd.Flags().StringVar(&sequenceOutput, "output", "", "Output JSONL file (default: <sequence-id>.jsonl)")
|
|
||||||
sequenceCmd.Flags().IntVar(&sequenceMaxTokens, "max-tokens", 1024, "Max tokens per response")
|
|
||||||
sequenceCmd.Flags().Float64Var(&sequenceTemp, "temperature", 0.4, "Sampling temperature")
|
|
||||||
sequenceCmd.Flags().IntVar(&sequenceMemLimit, "memory-limit", 24, "Metal memory limit in GB")
|
|
||||||
sequenceCmd.MarkFlagRequired("file")
|
|
||||||
sequenceCmd.MarkFlagRequired("model-path")
|
|
||||||
}
|
|
||||||
|
|
||||||
// sequenceDef is a YAML sequence definition.
|
|
||||||
type sequenceDef struct {
|
|
||||||
ID string `yaml:"id"`
|
|
||||||
Title string `yaml:"title"`
|
|
||||||
Mode string `yaml:"mode"` // "vertical" (default) or "horizontal"
|
|
||||||
ModelPath string `yaml:"model-path"`
|
|
||||||
Lessons []string `yaml:"lessons"` // Relative paths to lesson YAML files
|
|
||||||
}
|
|
||||||
|
|
||||||
// sequenceState tracks progress through a sequence.
|
|
||||||
type sequenceState struct {
|
|
||||||
SequenceID string `json:"sequence_id"`
|
|
||||||
Completed map[string]bool `json:"completed"` // lesson ID → done
|
|
||||||
Current string `json:"current"`
|
|
||||||
UpdatedAt string `json:"updated_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func runSequence(cmd *cli.Command, args []string) error {
|
|
||||||
start := time.Now()
|
|
||||||
|
|
||||||
// Load sequence YAML
|
|
||||||
data, err := os.ReadFile(sequenceFile)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("read sequence: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var seq sequenceDef
|
|
||||||
if err := yaml.Unmarshal(data, &seq); err != nil {
|
|
||||||
return fmt.Errorf("parse sequence: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if seq.ID == "" {
|
|
||||||
seq.ID = strings.TrimSuffix(filepath.Base(sequenceFile), filepath.Ext(sequenceFile))
|
|
||||||
}
|
|
||||||
if seq.Mode == "" {
|
|
||||||
seq.Mode = "vertical"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Model path from sequence or flag
|
|
||||||
modelPath := sequenceModelPath
|
|
||||||
if modelPath == "" && seq.ModelPath != "" {
|
|
||||||
modelPath = seq.ModelPath
|
|
||||||
}
|
|
||||||
if modelPath == "" {
|
|
||||||
return fmt.Errorf("model-path is required (flag or sequence YAML)")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Resolve output
|
|
||||||
if sequenceOutput == "" {
|
|
||||||
sequenceOutput = seq.ID + ".jsonl"
|
|
||||||
}
|
|
||||||
|
|
||||||
slog.Info("sequence: loaded",
|
|
||||||
"id", seq.ID,
|
|
||||||
"title", seq.Title,
|
|
||||||
"mode", seq.Mode,
|
|
||||||
"lessons", len(seq.Lessons),
|
|
||||||
)
|
|
||||||
|
|
||||||
// Load state
|
|
||||||
stateFile := seq.ID + ".sequence-state.json"
|
|
||||||
state := loadSequenceState(stateFile)
|
|
||||||
if state.SequenceID == "" {
|
|
||||||
state.SequenceID = seq.ID
|
|
||||||
state.Completed = make(map[string]bool)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load model once for all lessons
|
|
||||||
slog.Info("sequence: loading model", "path", modelPath)
|
|
||||||
backend, err := ml.NewMLXBackend(modelPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("load model: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
opts := ml.GenOpts{
|
|
||||||
Temperature: sequenceTemp,
|
|
||||||
MaxTokens: sequenceMaxTokens,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Open output file
|
|
||||||
outFile, err := os.OpenFile(sequenceOutput, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("create output: %w", err)
|
|
||||||
}
|
|
||||||
defer outFile.Close()
|
|
||||||
encoder := json.NewEncoder(outFile)
|
|
||||||
|
|
||||||
baseDir := filepath.Dir(sequenceFile)
|
|
||||||
totalGenerated := 0
|
|
||||||
|
|
||||||
for i, lessonPath := range seq.Lessons {
|
|
||||||
// Resolve lesson path
|
|
||||||
if !filepath.IsAbs(lessonPath) {
|
|
||||||
lessonPath = filepath.Join(baseDir, lessonPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load lesson
|
|
||||||
lessonData, err := os.ReadFile(lessonPath)
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("sequence: failed to read lesson",
|
|
||||||
"path", lessonPath,
|
|
||||||
"error", err,
|
|
||||||
)
|
|
||||||
if seq.Mode == "vertical" {
|
|
||||||
return fmt.Errorf("vertical sequence halted: %w", err)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
var lesson lessonDef
|
|
||||||
if err := yaml.Unmarshal(lessonData, &lesson); err != nil {
|
|
||||||
slog.Error("sequence: failed to parse lesson",
|
|
||||||
"path", lessonPath,
|
|
||||||
"error", err,
|
|
||||||
)
|
|
||||||
if seq.Mode == "vertical" {
|
|
||||||
return fmt.Errorf("vertical sequence halted: %w", err)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if lesson.ID == "" {
|
|
||||||
lesson.ID = strings.TrimSuffix(filepath.Base(lessonPath), filepath.Ext(lessonPath))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Skip completed lessons
|
|
||||||
if state.Completed[lesson.ID] {
|
|
||||||
slog.Info("sequence: skipping completed lesson",
|
|
||||||
"lesson", fmt.Sprintf("%d/%d", i+1, len(seq.Lessons)),
|
|
||||||
"id", lesson.ID,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
state.Current = lesson.ID
|
|
||||||
|
|
||||||
slog.Info("sequence: starting lesson",
|
|
||||||
"lesson", fmt.Sprintf("%d/%d", i+1, len(seq.Lessons)),
|
|
||||||
"id", lesson.ID,
|
|
||||||
"title", lesson.Title,
|
|
||||||
"prompts", len(lesson.Prompts),
|
|
||||||
)
|
|
||||||
|
|
||||||
// Load sandwich files for this lesson
|
|
||||||
var kbText, kernelText string
|
|
||||||
hasSandwich := false
|
|
||||||
if lesson.Sandwich != nil {
|
|
||||||
lessonDir := filepath.Dir(lessonPath)
|
|
||||||
if lesson.Sandwich.KB != "" {
|
|
||||||
kbPath := lesson.Sandwich.KB
|
|
||||||
if !filepath.IsAbs(kbPath) {
|
|
||||||
kbPath = filepath.Join(lessonDir, kbPath)
|
|
||||||
}
|
|
||||||
d, err := os.ReadFile(kbPath)
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("sequence: failed to read KB", "error", err)
|
|
||||||
} else {
|
|
||||||
kbText = string(d)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if lesson.Sandwich.Kernel != "" {
|
|
||||||
kernelPath := lesson.Sandwich.Kernel
|
|
||||||
if !filepath.IsAbs(kernelPath) {
|
|
||||||
kernelPath = filepath.Join(lessonDir, kernelPath)
|
|
||||||
}
|
|
||||||
d, err := os.ReadFile(kernelPath)
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("sequence: failed to read kernel", "error", err)
|
|
||||||
} else {
|
|
||||||
kernelText = string(d)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
hasSandwich = kbText != "" && kernelText != ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run each prompt in the lesson
|
|
||||||
generated := 0
|
|
||||||
for j, prompt := range lesson.Prompts {
|
|
||||||
var messages []ml.Message
|
|
||||||
if lesson.System != "" {
|
|
||||||
messages = append(messages, ml.Message{Role: "system", Content: lesson.System})
|
|
||||||
}
|
|
||||||
|
|
||||||
userContent := prompt.Prompt
|
|
||||||
if hasSandwich {
|
|
||||||
userContent = buildSandwich(kbText, prompt.Prompt, kernelText)
|
|
||||||
}
|
|
||||||
messages = append(messages, ml.Message{Role: "user", Content: userContent})
|
|
||||||
|
|
||||||
slog.Info("sequence: generating",
|
|
||||||
"lesson", lesson.ID,
|
|
||||||
"prompt", fmt.Sprintf("%d/%d", j+1, len(lesson.Prompts)),
|
|
||||||
"id", prompt.ID,
|
|
||||||
)
|
|
||||||
|
|
||||||
response, err := backend.Chat(cmd.Context(), messages, opts)
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("sequence: generation failed",
|
|
||||||
"lesson", lesson.ID,
|
|
||||||
"prompt", prompt.ID,
|
|
||||||
"error", err,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
record := struct {
|
|
||||||
Messages []ml.Message `json:"messages"`
|
|
||||||
}{
|
|
||||||
Messages: []ml.Message{
|
|
||||||
{Role: "user", Content: userContent},
|
|
||||||
{Role: "assistant", Content: response},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
if err := encoder.Encode(record); err != nil {
|
|
||||||
return fmt.Errorf("write record: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
generated++
|
|
||||||
totalGenerated++
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mark lesson complete
|
|
||||||
state.Completed[lesson.ID] = true
|
|
||||||
state.UpdatedAt = time.Now().Format(time.RFC3339)
|
|
||||||
saveSequenceState(stateFile, state)
|
|
||||||
|
|
||||||
slog.Info("sequence: lesson complete",
|
|
||||||
"id", lesson.ID,
|
|
||||||
"generated", generated,
|
|
||||||
"total", len(lesson.Prompts),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
state.Current = ""
|
|
||||||
state.UpdatedAt = time.Now().Format(time.RFC3339)
|
|
||||||
saveSequenceState(stateFile, state)
|
|
||||||
|
|
||||||
slog.Info("sequence: complete",
|
|
||||||
"id", seq.ID,
|
|
||||||
"output", sequenceOutput,
|
|
||||||
"total_generated", totalGenerated,
|
|
||||||
"lessons_completed", len(state.Completed),
|
|
||||||
"duration", time.Since(start).Round(time.Second),
|
|
||||||
)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func loadSequenceState(path string) sequenceState {
|
|
||||||
data, err := os.ReadFile(path)
|
|
||||||
if err != nil {
|
|
||||||
return sequenceState{}
|
|
||||||
}
|
|
||||||
var state sequenceState
|
|
||||||
json.Unmarshal(data, &state)
|
|
||||||
return state
|
|
||||||
}
|
|
||||||
|
|
||||||
func saveSequenceState(path string, state sequenceState) {
|
|
||||||
data, err := json.MarshalIndent(state, "", " ")
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
os.WriteFile(path, data, 0644)
|
|
||||||
}
|
|
||||||
|
|
@ -1,21 +1,15 @@
|
||||||
package ml
|
package ml
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
|
||||||
"os/signal"
|
|
||||||
"runtime"
|
|
||||||
"sync/atomic"
|
|
||||||
"syscall"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-ai/ml"
|
"forge.lthn.ai/core/go/pkg/ml"
|
||||||
)
|
)
|
||||||
|
|
||||||
var serveCmd = &cli.Command{
|
var serveCmd = &cli.Command{
|
||||||
|
|
@ -28,21 +22,11 @@ var serveCmd = &cli.Command{
|
||||||
var (
|
var (
|
||||||
serveBind string
|
serveBind string
|
||||||
serveModelPath string
|
serveModelPath string
|
||||||
serveThreads int
|
|
||||||
serveMaxTokens int
|
|
||||||
serveTimeout int
|
|
||||||
serveMaxRequests int
|
|
||||||
serveMaxContext int
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
serveCmd.Flags().StringVar(&serveBind, "bind", "0.0.0.0:8090", "Address to bind")
|
serveCmd.Flags().StringVar(&serveBind, "bind", "0.0.0.0:8090", "Address to bind")
|
||||||
serveCmd.Flags().StringVar(&serveModelPath, "model-path", "", "Path to model directory (for mlx backend)")
|
serveCmd.Flags().StringVar(&serveModelPath, "model-path", "", "Path to model directory (for mlx backend)")
|
||||||
serveCmd.Flags().IntVar(&serveThreads, "threads", 0, "Max CPU threads (0 = all available)")
|
|
||||||
serveCmd.Flags().IntVar(&serveMaxTokens, "max-tokens", 4096, "Default max tokens per request")
|
|
||||||
serveCmd.Flags().IntVar(&serveTimeout, "timeout", 300, "Request timeout in seconds")
|
|
||||||
serveCmd.Flags().IntVar(&serveMaxRequests, "max-requests", 1, "Max concurrent requests (Metal is single-stream)")
|
|
||||||
serveCmd.Flags().IntVar(&serveMaxContext, "max-context", 4, "Max chat messages to keep (sliding window, 0=unlimited)")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type completionRequest struct {
|
type completionRequest struct {
|
||||||
|
|
@ -50,7 +34,6 @@ type completionRequest struct {
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
MaxTokens int `json:"max_tokens"`
|
MaxTokens int `json:"max_tokens"`
|
||||||
Temperature float64 `json:"temperature"`
|
Temperature float64 `json:"temperature"`
|
||||||
Stream bool `json:"stream"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type completionResponse struct {
|
type completionResponse struct {
|
||||||
|
|
@ -73,7 +56,6 @@ type chatRequest struct {
|
||||||
Messages []ml.Message `json:"messages"`
|
Messages []ml.Message `json:"messages"`
|
||||||
MaxTokens int `json:"max_tokens"`
|
MaxTokens int `json:"max_tokens"`
|
||||||
Temperature float64 `json:"temperature"`
|
Temperature float64 `json:"temperature"`
|
||||||
Stream bool `json:"stream"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type chatResponse struct {
|
type chatResponse struct {
|
||||||
|
|
@ -90,40 +72,6 @@ type chatChoice struct {
|
||||||
FinishReason string `json:"finish_reason"`
|
FinishReason string `json:"finish_reason"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// SSE streaming types (OpenAI chunk format)
|
|
||||||
type chatChunkResponse struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Object string `json:"object"`
|
|
||||||
Created int64 `json:"created"`
|
|
||||||
Model string `json:"model"`
|
|
||||||
Choices []chatChunkChoice `json:"choices"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type chatChunkChoice struct {
|
|
||||||
Delta chatChunkDelta `json:"delta"`
|
|
||||||
Index int `json:"index"`
|
|
||||||
FinishReason *string `json:"finish_reason"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type chatChunkDelta struct {
|
|
||||||
Role string `json:"role,omitempty"`
|
|
||||||
Content string `json:"content,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type completionChunkResponse struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Object string `json:"object"`
|
|
||||||
Created int64 `json:"created"`
|
|
||||||
Model string `json:"model"`
|
|
||||||
Choices []completionChunkChoice `json:"choices"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type completionChunkChoice struct {
|
|
||||||
Text string `json:"text"`
|
|
||||||
Index int `json:"index"`
|
|
||||||
FinishReason *string `json:"finish_reason"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type usageInfo struct {
|
type usageInfo struct {
|
||||||
PromptTokens int `json:"prompt_tokens"`
|
PromptTokens int `json:"prompt_tokens"`
|
||||||
CompletionTokens int `json:"completion_tokens"`
|
CompletionTokens int `json:"completion_tokens"`
|
||||||
|
|
@ -131,54 +79,16 @@ type usageInfo struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func runServe(cmd *cli.Command, args []string) error {
|
func runServe(cmd *cli.Command, args []string) error {
|
||||||
// Cap CPU threads
|
// Try native MLX backend first (macOS arm64 with mlx tag + model-path set),
|
||||||
if serveThreads > 0 {
|
// fall back to HTTP proxy backend.
|
||||||
prev := runtime.GOMAXPROCS(serveThreads)
|
|
||||||
slog.Info("ml serve: capped threads", "threads", serveThreads, "previous", prev)
|
|
||||||
}
|
|
||||||
|
|
||||||
backend, err := createServeBackend()
|
backend, err := createServeBackend()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if backend supports streaming
|
|
||||||
streamer, canStream := backend.(ml.StreamingBackend)
|
|
||||||
|
|
||||||
// Request tracking
|
|
||||||
var activeRequests atomic.Int32
|
|
||||||
startTime := time.Now()
|
|
||||||
|
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
|
|
||||||
// Health endpoint
|
|
||||||
mux.HandleFunc("GET /healthz", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
json.NewEncoder(w).Encode(map[string]any{
|
|
||||||
"status": "ok",
|
|
||||||
"model": backend.Name(),
|
|
||||||
"uptime_seconds": int(time.Since(startTime).Seconds()),
|
|
||||||
"active_requests": activeRequests.Load(),
|
|
||||||
"max_threads": runtime.GOMAXPROCS(0),
|
|
||||||
"max_tokens": serveMaxTokens,
|
|
||||||
"max_context": serveMaxContext,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||||
// Concurrency gate
|
|
||||||
if int(activeRequests.Load()) >= serveMaxRequests {
|
|
||||||
http.Error(w, `{"error":"server busy, max concurrent requests reached"}`, http.StatusTooManyRequests)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
activeRequests.Add(1)
|
|
||||||
defer activeRequests.Add(-1)
|
|
||||||
|
|
||||||
// Request timeout
|
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), time.Duration(serveTimeout)*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
r = r.WithContext(ctx)
|
|
||||||
|
|
||||||
body, _ := io.ReadAll(r.Body)
|
body, _ := io.ReadAll(r.Body)
|
||||||
var req completionRequest
|
var req completionRequest
|
||||||
if err := json.Unmarshal(body, &req); err != nil {
|
if err := json.Unmarshal(body, &req); err != nil {
|
||||||
|
|
@ -186,67 +96,12 @@ func runServe(cmd *cli.Command, args []string) error {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Enforce server-level max-tokens cap
|
|
||||||
if req.MaxTokens == 0 || req.MaxTokens > serveMaxTokens {
|
|
||||||
req.MaxTokens = serveMaxTokens
|
|
||||||
}
|
|
||||||
|
|
||||||
opts := ml.GenOpts{
|
opts := ml.GenOpts{
|
||||||
Temperature: req.Temperature,
|
Temperature: req.Temperature,
|
||||||
MaxTokens: req.MaxTokens,
|
MaxTokens: req.MaxTokens,
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Streaming path
|
|
||||||
if req.Stream && canStream {
|
|
||||||
id := fmt.Sprintf("cmpl-%d", time.Now().UnixNano())
|
|
||||||
created := time.Now().Unix()
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "text/event-stream")
|
|
||||||
w.Header().Set("Cache-Control", "no-cache")
|
|
||||||
w.Header().Set("Connection", "keep-alive")
|
|
||||||
w.Header().Set("X-Accel-Buffering", "no")
|
|
||||||
flusher, ok := w.(http.Flusher)
|
|
||||||
if !ok {
|
|
||||||
http.Error(w, "streaming not supported", 500)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err := streamer.GenerateStream(r.Context(), req.Prompt, opts, func(token string) error {
|
|
||||||
chunk := completionChunkResponse{
|
|
||||||
ID: id,
|
|
||||||
Object: "text_completion",
|
|
||||||
Created: created,
|
|
||||||
Model: backend.Name(),
|
|
||||||
Choices: []completionChunkChoice{{Text: token}},
|
|
||||||
}
|
|
||||||
data, _ := json.Marshal(chunk)
|
|
||||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
|
||||||
flusher.Flush()
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("stream error", "err", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send final chunk with finish_reason
|
|
||||||
stop := "stop"
|
|
||||||
final := completionChunkResponse{
|
|
||||||
ID: id,
|
|
||||||
Object: "text_completion",
|
|
||||||
Created: created,
|
|
||||||
Model: backend.Name(),
|
|
||||||
Choices: []completionChunkChoice{{FinishReason: &stop}},
|
|
||||||
}
|
|
||||||
data, _ := json.Marshal(final)
|
|
||||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
|
||||||
fmt.Fprintf(w, "data: [DONE]\n\n")
|
|
||||||
flusher.Flush()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Non-streaming path
|
|
||||||
text, err := backend.Generate(r.Context(), req.Prompt, opts)
|
text, err := backend.Generate(r.Context(), req.Prompt, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), 500)
|
http.Error(w, err.Error(), 500)
|
||||||
|
|
@ -266,19 +121,6 @@ func runServe(cmd *cli.Command, args []string) error {
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("POST /v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("POST /v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||||
// Concurrency gate
|
|
||||||
if int(activeRequests.Load()) >= serveMaxRequests {
|
|
||||||
http.Error(w, `{"error":"server busy, max concurrent requests reached"}`, http.StatusTooManyRequests)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
activeRequests.Add(1)
|
|
||||||
defer activeRequests.Add(-1)
|
|
||||||
|
|
||||||
// Request timeout
|
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), time.Duration(serveTimeout)*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
r = r.WithContext(ctx)
|
|
||||||
|
|
||||||
body, _ := io.ReadAll(r.Body)
|
body, _ := io.ReadAll(r.Body)
|
||||||
var req chatRequest
|
var req chatRequest
|
||||||
if err := json.Unmarshal(body, &req); err != nil {
|
if err := json.Unmarshal(body, &req); err != nil {
|
||||||
|
|
@ -286,97 +128,12 @@ func runServe(cmd *cli.Command, args []string) error {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Enforce server-level max-tokens cap
|
|
||||||
if req.MaxTokens == 0 || req.MaxTokens > serveMaxTokens {
|
|
||||||
req.MaxTokens = serveMaxTokens
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sliding window: keep system prompt (if any) + last N messages
|
|
||||||
// Prevents KV-cache explosion on multi-turn conversations
|
|
||||||
if serveMaxContext > 0 && len(req.Messages) > serveMaxContext {
|
|
||||||
var kept []ml.Message
|
|
||||||
rest := req.Messages
|
|
||||||
// Preserve system message if present
|
|
||||||
if len(rest) > 0 && rest[0].Role == "system" {
|
|
||||||
kept = append(kept, rest[0])
|
|
||||||
rest = rest[1:]
|
|
||||||
}
|
|
||||||
// Keep only the last N user/assistant messages
|
|
||||||
if len(rest) > serveMaxContext {
|
|
||||||
rest = rest[len(rest)-serveMaxContext:]
|
|
||||||
}
|
|
||||||
req.Messages = append(kept, rest...)
|
|
||||||
slog.Debug("ml serve: context window applied", "kept", len(req.Messages))
|
|
||||||
}
|
|
||||||
|
|
||||||
opts := ml.GenOpts{
|
opts := ml.GenOpts{
|
||||||
Temperature: req.Temperature,
|
Temperature: req.Temperature,
|
||||||
MaxTokens: req.MaxTokens,
|
MaxTokens: req.MaxTokens,
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Streaming path
|
|
||||||
if req.Stream && canStream {
|
|
||||||
id := fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano())
|
|
||||||
created := time.Now().Unix()
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "text/event-stream")
|
|
||||||
w.Header().Set("Cache-Control", "no-cache")
|
|
||||||
w.Header().Set("Connection", "keep-alive")
|
|
||||||
w.Header().Set("X-Accel-Buffering", "no")
|
|
||||||
flusher, ok := w.(http.Flusher)
|
|
||||||
if !ok {
|
|
||||||
http.Error(w, "streaming not supported", 500)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send initial role chunk
|
|
||||||
roleChunk := chatChunkResponse{
|
|
||||||
ID: id,
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
Created: created,
|
|
||||||
Model: backend.Name(),
|
|
||||||
Choices: []chatChunkChoice{{Delta: chatChunkDelta{Role: "assistant"}}},
|
|
||||||
}
|
|
||||||
data, _ := json.Marshal(roleChunk)
|
|
||||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
|
||||||
flusher.Flush()
|
|
||||||
|
|
||||||
err := streamer.ChatStream(r.Context(), req.Messages, opts, func(token string) error {
|
|
||||||
chunk := chatChunkResponse{
|
|
||||||
ID: id,
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
Created: created,
|
|
||||||
Model: backend.Name(),
|
|
||||||
Choices: []chatChunkChoice{{Delta: chatChunkDelta{Content: token}}},
|
|
||||||
}
|
|
||||||
data, _ := json.Marshal(chunk)
|
|
||||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
|
||||||
flusher.Flush()
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("stream error", "err", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send final chunk with finish_reason
|
|
||||||
stop := "stop"
|
|
||||||
final := chatChunkResponse{
|
|
||||||
ID: id,
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
Created: created,
|
|
||||||
Model: backend.Name(),
|
|
||||||
Choices: []chatChunkChoice{{Delta: chatChunkDelta{}, FinishReason: &stop}},
|
|
||||||
}
|
|
||||||
data, _ = json.Marshal(final)
|
|
||||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
|
||||||
fmt.Fprintf(w, "data: [DONE]\n\n")
|
|
||||||
flusher.Flush()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Non-streaming path
|
|
||||||
text, err := backend.Chat(r.Context(), req.Messages, opts)
|
text, err := backend.Chat(r.Context(), req.Messages, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), 500)
|
http.Error(w, err.Error(), 500)
|
||||||
|
|
@ -414,59 +171,7 @@ func runServe(cmd *cli.Command, args []string) error {
|
||||||
json.NewEncoder(w).Encode(resp)
|
json.NewEncoder(w).Encode(resp)
|
||||||
})
|
})
|
||||||
|
|
||||||
// Serve the lem-chat UI at root — same origin, no CORS needed
|
slog.Info("ml serve: starting", "bind", serveBind, "backend", backend.Name())
|
||||||
mux.HandleFunc("GET /chat.js", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Header().Set("Content-Type", "application/javascript")
|
|
||||||
w.Write(lemChatJS)
|
|
||||||
})
|
|
||||||
|
|
||||||
mux.HandleFunc("GET /", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.URL.Path != "/" {
|
|
||||||
http.NotFound(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
|
||||||
fmt.Fprintf(w, chatHTML, backend.Name(), serveMaxTokens)
|
|
||||||
})
|
|
||||||
|
|
||||||
slog.Info("ml serve: starting",
|
|
||||||
"bind", serveBind,
|
|
||||||
"backend", backend.Name(),
|
|
||||||
"streaming", canStream,
|
|
||||||
"threads", runtime.GOMAXPROCS(0),
|
|
||||||
"max_tokens", serveMaxTokens,
|
|
||||||
"max_context_msgs", serveMaxContext,
|
|
||||||
"timeout_s", serveTimeout,
|
|
||||||
"max_requests", serveMaxRequests,
|
|
||||||
)
|
|
||||||
fmt.Printf("Serving on http://%s\n", serveBind)
|
fmt.Printf("Serving on http://%s\n", serveBind)
|
||||||
|
return http.ListenAndServe(serveBind, mux)
|
||||||
// Graceful shutdown on SIGINT/SIGTERM
|
|
||||||
srv := &http.Server{
|
|
||||||
Addr: serveBind,
|
|
||||||
Handler: mux,
|
|
||||||
}
|
|
||||||
|
|
||||||
errCh := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
errCh <- srv.ListenAndServe()
|
|
||||||
}()
|
|
||||||
|
|
||||||
sigCh := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case sig := <-sigCh:
|
|
||||||
slog.Info("ml serve: shutting down", "signal", sig)
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if err := srv.Shutdown(ctx); err != nil {
|
|
||||||
slog.Error("ml serve: shutdown error", "err", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
slog.Info("ml serve: stopped cleanly")
|
|
||||||
return nil
|
|
||||||
case err := <-errCh:
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-ai/ml"
|
"forge.lthn.ai/core/go/pkg/ml"
|
||||||
)
|
)
|
||||||
|
|
||||||
var statusCmd = &cli.Command{
|
var statusCmd = &cli.Command{
|
||||||
|
|
|
||||||
|
|
@ -1,358 +0,0 @@
|
||||||
//go:build darwin && arm64
|
|
||||||
|
|
||||||
package ml
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"log/slog"
|
|
||||||
"os"
|
|
||||||
"runtime"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"forge.lthn.ai/core/go-ai/ml"
|
|
||||||
"forge.lthn.ai/core/go-ai/mlx"
|
|
||||||
"forge.lthn.ai/core/go-ai/mlx/model"
|
|
||||||
"forge.lthn.ai/core/go-ai/mlx/tokenizer"
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
|
||||||
)
|
|
||||||
|
|
||||||
var trainCmd = &cli.Command{
|
|
||||||
Use: "train",
|
|
||||||
Short: "LoRA fine-tune a model on JSONL training data",
|
|
||||||
Long: `Fine-tunes a local MLX model using LoRA (Low-Rank Adaptation).
|
|
||||||
|
|
||||||
Reads chat-format JSONL training data and trains LoRA adapter weights
|
|
||||||
using AdamW optimiser with cross-entropy loss on assistant tokens only.
|
|
||||||
|
|
||||||
Training data format (one JSON object per line):
|
|
||||||
{"messages": [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}`,
|
|
||||||
RunE: runTrain,
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
trainModelPath string
|
|
||||||
trainData string
|
|
||||||
trainOutput string
|
|
||||||
trainRank int
|
|
||||||
trainAlpha float64
|
|
||||||
trainLR float64
|
|
||||||
trainEpochs int
|
|
||||||
trainMaxSeqLen int
|
|
||||||
trainTargets string
|
|
||||||
trainMemoryLimit int
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
trainCmd.Flags().StringVar(&trainModelPath, "model-path", "", "Path to model directory (required)")
|
|
||||||
trainCmd.Flags().StringVar(&trainData, "data", "", "Training JSONL file (required)")
|
|
||||||
trainCmd.Flags().StringVar(&trainOutput, "output", "adapters.safetensors", "Output adapter file")
|
|
||||||
trainCmd.Flags().IntVar(&trainRank, "rank", 8, "LoRA decomposition rank")
|
|
||||||
trainCmd.Flags().Float64Var(&trainAlpha, "alpha", 16, "LoRA scaling factor")
|
|
||||||
trainCmd.Flags().Float64Var(&trainLR, "lr", 1e-4, "Learning rate")
|
|
||||||
trainCmd.Flags().IntVar(&trainEpochs, "epochs", 1, "Number of training epochs")
|
|
||||||
trainCmd.Flags().IntVar(&trainMaxSeqLen, "max-seq-len", 512, "Maximum sequence length (tokens)")
|
|
||||||
trainCmd.Flags().StringVar(&trainTargets, "targets", "q_proj,v_proj", "Comma-separated projection targets for LoRA")
|
|
||||||
trainCmd.Flags().IntVar(&trainMemoryLimit, "memory-limit", 24, "Metal memory limit in GB")
|
|
||||||
trainCmd.MarkFlagRequired("model-path")
|
|
||||||
trainCmd.MarkFlagRequired("data")
|
|
||||||
}
|
|
||||||
|
|
||||||
// trainSample holds a tokenised training example.
|
|
||||||
type trainSample struct {
|
|
||||||
Tokens []int32 // Full token sequence
|
|
||||||
Mask []int32 // 1 for assistant tokens, 0 for prompt tokens
|
|
||||||
}
|
|
||||||
|
|
||||||
func runTrain(cmd *cli.Command, args []string) error {
|
|
||||||
start := time.Now()
|
|
||||||
|
|
||||||
// --- Load model ---
|
|
||||||
slog.Info("loading model", "path", trainModelPath)
|
|
||||||
m, err := model.LoadModel(trainModelPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("load model: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
mlx.SetCacheLimit(uint64(trainMemoryLimit) * 1024 * 1024 * 1024)
|
|
||||||
mlx.SetMemoryLimit(uint64(trainMemoryLimit) * 1024 * 1024 * 1024)
|
|
||||||
|
|
||||||
tok := m.Tokenizer()
|
|
||||||
slog.Info("model loaded",
|
|
||||||
"type", m.ModelType(),
|
|
||||||
"layers", m.NumLayers(),
|
|
||||||
)
|
|
||||||
|
|
||||||
// --- Apply LoRA ---
|
|
||||||
targets := strings.Split(trainTargets, ",")
|
|
||||||
cfg := mlx.LoRAConfig{
|
|
||||||
Rank: trainRank,
|
|
||||||
Alpha: float32(trainAlpha),
|
|
||||||
TargetKeys: targets,
|
|
||||||
}
|
|
||||||
|
|
||||||
adapter := m.ApplyLoRA(cfg)
|
|
||||||
slog.Info("LoRA applied",
|
|
||||||
"rank", cfg.Rank,
|
|
||||||
"alpha", cfg.Alpha,
|
|
||||||
"targets", targets,
|
|
||||||
"trainable_params", adapter.TotalParams(),
|
|
||||||
"layers", len(adapter.Layers),
|
|
||||||
)
|
|
||||||
|
|
||||||
// --- Load training data ---
|
|
||||||
samples, err := loadTrainingSamples(trainData, tok, m.ModelType(), trainMaxSeqLen)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("load training data: %w", err)
|
|
||||||
}
|
|
||||||
slog.Info("training data loaded", "samples", len(samples))
|
|
||||||
|
|
||||||
if len(samples) == 0 {
|
|
||||||
return fmt.Errorf("no training samples loaded")
|
|
||||||
}
|
|
||||||
|
|
||||||
// --- Training loop ---
|
|
||||||
params := adapter.AllTrainableParams()
|
|
||||||
opt := mlx.NewAdamW(trainLR)
|
|
||||||
|
|
||||||
// Build argument indices for ValueAndGrad (all params)
|
|
||||||
argIndices := make([]int, len(params))
|
|
||||||
for i := range argIndices {
|
|
||||||
argIndices[i] = i
|
|
||||||
}
|
|
||||||
|
|
||||||
var totalLoss float64
|
|
||||||
var totalSteps int
|
|
||||||
|
|
||||||
for epoch := 0; epoch < trainEpochs; epoch++ {
|
|
||||||
var epochLoss float64
|
|
||||||
epochStart := time.Now()
|
|
||||||
|
|
||||||
for si, sample := range samples {
|
|
||||||
// Build token tensors: input = tokens[:-1], target = tokens[1:]
|
|
||||||
seqLen := len(sample.Tokens)
|
|
||||||
if seqLen < 2 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
inputTokens := sample.Tokens[:seqLen-1]
|
|
||||||
targetTokens := sample.Tokens[1:]
|
|
||||||
maskTokens := sample.Mask[1:] // mask aligned with targets
|
|
||||||
|
|
||||||
inputArr := mlx.FromValues(inputTokens, 1, len(inputTokens))
|
|
||||||
targetArr := mlx.FromValues(targetTokens, 1, len(targetTokens))
|
|
||||||
|
|
||||||
// Build float32 mask
|
|
||||||
maskF32 := make([]float32, len(maskTokens))
|
|
||||||
for i, m := range maskTokens {
|
|
||||||
maskF32[i] = float32(m)
|
|
||||||
}
|
|
||||||
maskArr := mlx.FromValues(maskF32, 1, len(maskF32))
|
|
||||||
mlx.Materialize(inputArr, targetArr, maskArr)
|
|
||||||
|
|
||||||
// Loss function closure — takes LoRA params as inputs
|
|
||||||
lossFn := func(inputs []*mlx.Array) []*mlx.Array {
|
|
||||||
// Set LoRA params from inputs
|
|
||||||
adapter.SetAllParams(inputs)
|
|
||||||
|
|
||||||
// Forward pass with fresh caches (no KV caching for training)
|
|
||||||
caches := m.NewCache()
|
|
||||||
logits := m.Forward(inputArr, caches)
|
|
||||||
|
|
||||||
// Cast targets to int32 for take_along_axis
|
|
||||||
loss := mlx.MaskedCrossEntropyLoss(logits, targetArr, maskArr)
|
|
||||||
return []*mlx.Array{loss}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compute value and gradients
|
|
||||||
grad := mlx.ValueAndGrad(lossFn, argIndices...)
|
|
||||||
values, grads, err := grad.Apply(params...)
|
|
||||||
grad.Free()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("epoch %d sample %d: gradient failed: %w", epoch, si, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
mlx.Materialize(append(values, grads...)...)
|
|
||||||
|
|
||||||
loss := values[0].Float()
|
|
||||||
epochLoss += loss
|
|
||||||
totalSteps++
|
|
||||||
|
|
||||||
// Update parameters
|
|
||||||
params = opt.Step(params, grads)
|
|
||||||
adapter.SetAllParams(params)
|
|
||||||
mlx.Materialize(params...)
|
|
||||||
|
|
||||||
// Periodic cleanup
|
|
||||||
if totalSteps%4 == 0 {
|
|
||||||
runtime.GC()
|
|
||||||
mlx.ClearCache()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Log progress
|
|
||||||
if (si+1)%10 == 0 || si == len(samples)-1 {
|
|
||||||
avgLoss := epochLoss / float64(si+1)
|
|
||||||
slog.Info("training",
|
|
||||||
"epoch", epoch+1,
|
|
||||||
"step", fmt.Sprintf("%d/%d", si+1, len(samples)),
|
|
||||||
"loss", fmt.Sprintf("%.4f", loss),
|
|
||||||
"avg_loss", fmt.Sprintf("%.4f", avgLoss),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
totalLoss = epochLoss / float64(len(samples))
|
|
||||||
elapsed := time.Since(epochStart)
|
|
||||||
slog.Info("epoch complete",
|
|
||||||
"epoch", epoch+1,
|
|
||||||
"avg_loss", fmt.Sprintf("%.4f", totalLoss),
|
|
||||||
"duration", elapsed.Round(time.Second),
|
|
||||||
"samples_per_sec", fmt.Sprintf("%.1f", float64(len(samples))/elapsed.Seconds()),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// --- Save adapter ---
|
|
||||||
if err := adapter.Save(trainOutput); err != nil {
|
|
||||||
return fmt.Errorf("save adapter: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
slog.Info("training complete",
|
|
||||||
"output", trainOutput,
|
|
||||||
"total_steps", totalSteps,
|
|
||||||
"final_loss", fmt.Sprintf("%.4f", totalLoss),
|
|
||||||
"duration", time.Since(start).Round(time.Second),
|
|
||||||
"trainable_params", adapter.TotalParams(),
|
|
||||||
)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// loadTrainingSamples reads JSONL and tokenises each conversation.
|
|
||||||
func loadTrainingSamples(path string, tok *tokenizer.Tokenizer, modelType string, maxSeqLen int) ([]trainSample, error) {
|
|
||||||
f, err := os.Open(path)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer f.Close()
|
|
||||||
|
|
||||||
var samples []trainSample
|
|
||||||
scanner := bufio.NewScanner(f)
|
|
||||||
scanner.Buffer(make([]byte, 1<<20), 1<<20) // 1MB line buffer
|
|
||||||
|
|
||||||
lineNum := 0
|
|
||||||
for scanner.Scan() {
|
|
||||||
lineNum++
|
|
||||||
line := strings.TrimSpace(scanner.Text())
|
|
||||||
if line == "" || strings.HasPrefix(line, "#") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
var entry struct {
|
|
||||||
Messages []ml.Message `json:"messages"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal([]byte(line), &entry); err != nil {
|
|
||||||
slog.Warn("skipping invalid line", "line", lineNum, "error", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(entry.Messages) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
sample := tokeniseConversation(entry.Messages, tok, modelType, maxSeqLen)
|
|
||||||
if sample != nil {
|
|
||||||
samples = append(samples, *sample)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return samples, scanner.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
// tokeniseConversation formats and tokenises a conversation, creating a mask
|
|
||||||
// that is 1 for assistant tokens and 0 for system/user tokens.
|
|
||||||
func tokeniseConversation(messages []ml.Message, tok *tokenizer.Tokenizer, modelType string, maxSeqLen int) *trainSample {
|
|
||||||
// Strategy: tokenise the full conversation, then tokenise just the prefix
|
|
||||||
// (non-assistant parts) to determine the mask boundary.
|
|
||||||
|
|
||||||
// Build full conversation text
|
|
||||||
fullText := formatConversation(messages, modelType, true)
|
|
||||||
fullTokens := tok.Encode(fullText)
|
|
||||||
|
|
||||||
if len(fullTokens) < 2 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Truncate to max sequence length
|
|
||||||
if len(fullTokens) > maxSeqLen {
|
|
||||||
fullTokens = fullTokens[:maxSeqLen]
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build mask: tokenise prefix (everything up to last assistant response)
|
|
||||||
// then mark remaining tokens as assistant (mask=1)
|
|
||||||
prefixText := formatConversation(messages, modelType, false)
|
|
||||||
prefixTokens := tok.Encode(prefixText)
|
|
||||||
|
|
||||||
mask := make([]int32, len(fullTokens))
|
|
||||||
for i := range mask {
|
|
||||||
if i >= len(prefixTokens) {
|
|
||||||
mask[i] = 1 // assistant token
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &trainSample{
|
|
||||||
Tokens: fullTokens,
|
|
||||||
Mask: mask,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// formatConversation formats messages using the model's chat template.
|
|
||||||
// If includeAssistant is false, only formats up to the last assistant turn header.
|
|
||||||
func formatConversation(messages []ml.Message, modelType string, includeAssistant bool) string {
|
|
||||||
switch modelType {
|
|
||||||
case "qwen3":
|
|
||||||
return formatQwen3Train(messages, includeAssistant)
|
|
||||||
default:
|
|
||||||
return formatGemmaTrain(messages, includeAssistant)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func formatQwen3Train(messages []ml.Message, includeAssistant bool) string {
|
|
||||||
var sb strings.Builder
|
|
||||||
for _, msg := range messages {
|
|
||||||
if msg.Role == "assistant" && !includeAssistant {
|
|
||||||
// Write the assistant header but not the content
|
|
||||||
sb.WriteString("<|im_start|>assistant\n")
|
|
||||||
return sb.String()
|
|
||||||
}
|
|
||||||
switch msg.Role {
|
|
||||||
case "system":
|
|
||||||
sb.WriteString(fmt.Sprintf("<|im_start|>system\n%s<|im_end|>\n", msg.Content))
|
|
||||||
case "user":
|
|
||||||
sb.WriteString(fmt.Sprintf("<|im_start|>user\n%s<|im_end|>\n", msg.Content))
|
|
||||||
case "assistant":
|
|
||||||
sb.WriteString(fmt.Sprintf("<|im_start|>assistant\n%s<|im_end|>\n", msg.Content))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return sb.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func formatGemmaTrain(messages []ml.Message, includeAssistant bool) string {
|
|
||||||
var sb strings.Builder
|
|
||||||
for _, msg := range messages {
|
|
||||||
if msg.Role == "assistant" && !includeAssistant {
|
|
||||||
sb.WriteString("<start_of_turn>model\n")
|
|
||||||
return sb.String()
|
|
||||||
}
|
|
||||||
switch msg.Role {
|
|
||||||
case "user":
|
|
||||||
sb.WriteString(fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n", msg.Content))
|
|
||||||
case "assistant":
|
|
||||||
sb.WriteString(fmt.Sprintf("<start_of_turn>model\n%s<end_of_turn>\n", msg.Content))
|
|
||||||
case "system":
|
|
||||||
sb.WriteString(fmt.Sprintf("<start_of_turn>user\n[System: %s]<end_of_turn>\n", msg.Content))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return sb.String()
|
|
||||||
}
|
|
||||||
|
|
@ -1,7 +0,0 @@
|
||||||
//go:build darwin && arm64
|
|
||||||
|
|
||||||
package ml
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
mlCmd.AddCommand(trainCmd)
|
|
||||||
}
|
|
||||||
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-ai/ml"
|
"forge.lthn.ai/core/go/pkg/ml"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
//go:build !(darwin && arm64)
|
//go:build !(darwin && arm64 && mlx)
|
||||||
|
|
||||||
package ml
|
package ml
|
||||||
|
|
||||||
import "forge.lthn.ai/core/go-ai/ml"
|
import "forge.lthn.ai/core/go/pkg/ml"
|
||||||
|
|
||||||
func createServeBackend() (ml.Backend, error) {
|
func createServeBackend() (ml.Backend, error) {
|
||||||
return ml.NewHTTPBackend(apiURL, modelName), nil
|
return ml.NewHTTPBackend(apiURL, modelName), nil
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
//go:build darwin && arm64
|
//go:build darwin && arm64 && mlx
|
||||||
|
|
||||||
package ml
|
package ml
|
||||||
|
|
||||||
|
|
@ -6,7 +6,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go-ai/ml"
|
"forge.lthn.ai/core/go/pkg/ml"
|
||||||
)
|
)
|
||||||
|
|
||||||
func createServeBackend() (ml.Backend, error) {
|
func createServeBackend() (ml.Backend, error) {
|
||||||
|
|
|
||||||
|
|
@ -1,59 +0,0 @@
|
||||||
// Package module provides CLI commands for managing marketplace modules.
|
|
||||||
//
|
|
||||||
// Commands:
|
|
||||||
// - install: Install a module from a Git repo
|
|
||||||
// - list: List installed modules
|
|
||||||
// - update: Update a module or all modules
|
|
||||||
// - remove: Remove an installed module
|
|
||||||
package module
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
|
||||||
"forge.lthn.ai/core/go/pkg/i18n"
|
|
||||||
"forge.lthn.ai/core/go/pkg/marketplace"
|
|
||||||
"forge.lthn.ai/core/go/pkg/store"
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
cli.RegisterCommands(AddModuleCommands)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddModuleCommands registers the 'module' command and all subcommands.
|
|
||||||
func AddModuleCommands(root *cli.Command) {
|
|
||||||
moduleCmd := &cli.Command{
|
|
||||||
Use: "module",
|
|
||||||
Short: i18n.T("Manage marketplace modules"),
|
|
||||||
}
|
|
||||||
root.AddCommand(moduleCmd)
|
|
||||||
|
|
||||||
addInstallCommand(moduleCmd)
|
|
||||||
addListCommand(moduleCmd)
|
|
||||||
addUpdateCommand(moduleCmd)
|
|
||||||
addRemoveCommand(moduleCmd)
|
|
||||||
}
|
|
||||||
|
|
||||||
// moduleSetup returns the modules directory, store, and installer.
|
|
||||||
// The caller must defer st.Close().
|
|
||||||
func moduleSetup() (string, *store.Store, *marketplace.Installer, error) {
|
|
||||||
home, err := os.UserHomeDir()
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, nil, cli.Wrap(err, "failed to determine home directory")
|
|
||||||
}
|
|
||||||
|
|
||||||
modulesDir := filepath.Join(home, ".core", "modules")
|
|
||||||
if err := os.MkdirAll(modulesDir, 0755); err != nil {
|
|
||||||
return "", nil, nil, cli.Wrap(err, "failed to create modules directory")
|
|
||||||
}
|
|
||||||
|
|
||||||
dbPath := filepath.Join(modulesDir, "modules.db")
|
|
||||||
st, err := store.New(dbPath)
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, nil, cli.Wrap(err, "failed to open module store")
|
|
||||||
}
|
|
||||||
|
|
||||||
inst := marketplace.NewInstaller(modulesDir, st)
|
|
||||||
return modulesDir, st, inst, nil
|
|
||||||
}
|
|
||||||
|
|
@ -1,59 +0,0 @@
|
||||||
package module
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
|
||||||
"forge.lthn.ai/core/go/pkg/i18n"
|
|
||||||
"forge.lthn.ai/core/go/pkg/marketplace"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
installRepo string
|
|
||||||
installSignKey string
|
|
||||||
)
|
|
||||||
|
|
||||||
func addInstallCommand(parent *cli.Command) {
|
|
||||||
installCmd := cli.NewCommand(
|
|
||||||
"install <code>",
|
|
||||||
i18n.T("Install a module from a Git repo"),
|
|
||||||
i18n.T("Install a module by cloning its Git repository, verifying the manifest signature, and registering it.\n\nThe --repo flag is required and specifies the Git URL to clone from."),
|
|
||||||
func(cmd *cli.Command, args []string) error {
|
|
||||||
if installRepo == "" {
|
|
||||||
return fmt.Errorf("--repo flag is required")
|
|
||||||
}
|
|
||||||
return runInstall(args[0], installRepo, installSignKey)
|
|
||||||
},
|
|
||||||
)
|
|
||||||
installCmd.Args = cli.ExactArgs(1)
|
|
||||||
installCmd.Example = " core module install my-module --repo https://forge.lthn.ai/modules/my-module.git\n core module install signed-mod --repo ssh://git@forge.lthn.ai:2223/modules/signed.git --sign-key abc123"
|
|
||||||
|
|
||||||
cli.StringFlag(installCmd, &installRepo, "repo", "r", "", i18n.T("Git repository URL to clone"))
|
|
||||||
cli.StringFlag(installCmd, &installSignKey, "sign-key", "k", "", i18n.T("Hex-encoded ed25519 public key for manifest verification"))
|
|
||||||
|
|
||||||
parent.AddCommand(installCmd)
|
|
||||||
}
|
|
||||||
|
|
||||||
func runInstall(code, repo, signKey string) error {
|
|
||||||
_, st, inst, err := moduleSetup()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer st.Close()
|
|
||||||
|
|
||||||
cli.Dim("Installing module " + code + " from " + repo + "...")
|
|
||||||
|
|
||||||
mod := marketplace.Module{
|
|
||||||
Code: code,
|
|
||||||
Repo: repo,
|
|
||||||
SignKey: signKey,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := inst.Install(context.Background(), mod); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
cli.Success("Module " + code + " installed successfully")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
@ -1,51 +0,0 @@
|
||||||
package module
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
|
||||||
"forge.lthn.ai/core/go/pkg/i18n"
|
|
||||||
)
|
|
||||||
|
|
||||||
func addListCommand(parent *cli.Command) {
|
|
||||||
listCmd := cli.NewCommand(
|
|
||||||
"list",
|
|
||||||
i18n.T("List installed modules"),
|
|
||||||
"",
|
|
||||||
func(cmd *cli.Command, args []string) error {
|
|
||||||
return runList()
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
parent.AddCommand(listCmd)
|
|
||||||
}
|
|
||||||
|
|
||||||
func runList() error {
|
|
||||||
_, st, inst, err := moduleSetup()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer st.Close()
|
|
||||||
|
|
||||||
installed, err := inst.Installed()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(installed) == 0 {
|
|
||||||
cli.Dim("No modules installed")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
table := cli.NewTable("Code", "Name", "Version", "Repo")
|
|
||||||
for _, m := range installed {
|
|
||||||
table.AddRow(m.Code, m.Name, m.Version, m.Repo)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Println()
|
|
||||||
table.Render()
|
|
||||||
fmt.Println()
|
|
||||||
cli.Dim(fmt.Sprintf("%d module(s) installed", len(installed)))
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
@ -1,40 +0,0 @@
|
||||||
package module
|
|
||||||
|
|
||||||
import (
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
|
||||||
"forge.lthn.ai/core/go/pkg/i18n"
|
|
||||||
)
|
|
||||||
|
|
||||||
func addRemoveCommand(parent *cli.Command) {
|
|
||||||
removeCmd := cli.NewCommand(
|
|
||||||
"remove <code>",
|
|
||||||
i18n.T("Remove an installed module"),
|
|
||||||
"",
|
|
||||||
func(cmd *cli.Command, args []string) error {
|
|
||||||
return runRemove(args[0])
|
|
||||||
},
|
|
||||||
)
|
|
||||||
removeCmd.Args = cli.ExactArgs(1)
|
|
||||||
|
|
||||||
parent.AddCommand(removeCmd)
|
|
||||||
}
|
|
||||||
|
|
||||||
func runRemove(code string) error {
|
|
||||||
_, st, inst, err := moduleSetup()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer st.Close()
|
|
||||||
|
|
||||||
if !cli.Confirm("Remove module " + code + "?") {
|
|
||||||
cli.Dim("Cancelled")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := inst.Remove(code); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
cli.Success("Module " + code + " removed")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
@ -1,84 +0,0 @@
|
||||||
package module
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
|
||||||
"forge.lthn.ai/core/go/pkg/i18n"
|
|
||||||
)
|
|
||||||
|
|
||||||
var updateAll bool
|
|
||||||
|
|
||||||
func addUpdateCommand(parent *cli.Command) {
|
|
||||||
updateCmd := cli.NewCommand(
|
|
||||||
"update [code]",
|
|
||||||
i18n.T("Update a module or all modules"),
|
|
||||||
i18n.T("Update a specific module to the latest version, or use --all to update all installed modules."),
|
|
||||||
func(cmd *cli.Command, args []string) error {
|
|
||||||
if updateAll {
|
|
||||||
return runUpdateAll()
|
|
||||||
}
|
|
||||||
if len(args) == 0 {
|
|
||||||
return fmt.Errorf("module code required (or use --all)")
|
|
||||||
}
|
|
||||||
return runUpdate(args[0])
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
cli.BoolFlag(updateCmd, &updateAll, "all", "a", false, i18n.T("Update all installed modules"))
|
|
||||||
|
|
||||||
parent.AddCommand(updateCmd)
|
|
||||||
}
|
|
||||||
|
|
||||||
func runUpdate(code string) error {
|
|
||||||
_, st, inst, err := moduleSetup()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer st.Close()
|
|
||||||
|
|
||||||
cli.Dim("Updating " + code + "...")
|
|
||||||
|
|
||||||
if err := inst.Update(context.Background(), code); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
cli.Success("Module " + code + " updated successfully")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func runUpdateAll() error {
|
|
||||||
_, st, inst, err := moduleSetup()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer st.Close()
|
|
||||||
|
|
||||||
installed, err := inst.Installed()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(installed) == 0 {
|
|
||||||
cli.Dim("No modules installed")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
var updated, failed int
|
|
||||||
for _, m := range installed {
|
|
||||||
cli.Dim("Updating " + m.Code + "...")
|
|
||||||
if err := inst.Update(ctx, m.Code); err != nil {
|
|
||||||
cli.Errorf("Failed to update %s: %v", m.Code, err)
|
|
||||||
failed++
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
cli.Success(m.Code + " updated")
|
|
||||||
updated++
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Println()
|
|
||||||
cli.Dim(fmt.Sprintf("%d updated, %d failed", updated, failed))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
@ -7,7 +7,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-devops/infra"
|
"forge.lthn.ai/core/go/pkg/infra"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-devops/infra"
|
"forge.lthn.ai/core/go/pkg/infra"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-devops/infra"
|
"forge.lthn.ai/core/go/pkg/infra"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,9 +8,9 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go-devops/ansible"
|
"forge.lthn.ai/core/go/pkg/ansible"
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go-devops/infra"
|
"forge.lthn.ai/core/go/pkg/infra"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ import (
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go/pkg/i18n"
|
"forge.lthn.ai/core/go/pkg/i18n"
|
||||||
"forge.lthn.ai/core/go-ai/rag"
|
"forge.lthn.ai/core/go/pkg/rag"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ import (
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go/pkg/i18n"
|
"forge.lthn.ai/core/go/pkg/i18n"
|
||||||
"forge.lthn.ai/core/go-ai/rag"
|
"forge.lthn.ai/core/go/pkg/rag"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go/pkg/i18n"
|
"forge.lthn.ai/core/go/pkg/i18n"
|
||||||
"forge.lthn.ai/core/go-ai/rag"
|
"forge.lthn.ai/core/go/pkg/rag"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go-ai/ai"
|
"forge.lthn.ai/core/go/pkg/ai"
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go/pkg/i18n"
|
"forge.lthn.ai/core/go/pkg/i18n"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go-ai/ai"
|
"forge.lthn.ai/core/go/pkg/ai"
|
||||||
"forge.lthn.ai/core/go/pkg/cli"
|
"forge.lthn.ai/core/go/pkg/cli"
|
||||||
"forge.lthn.ai/core/go/pkg/i18n"
|
"forge.lthn.ai/core/go/pkg/i18n"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Reference in a new issue