cli/cmd/ml/cmd_serve.go

package ml

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"os"
	"os/signal"
	"runtime"
	"sync/atomic"
	"syscall"
	"time"

	"forge.lthn.ai/core/go-ai/ml"
	"forge.lthn.ai/core/go/pkg/cli"
)

var serveCmd = &cli.Command{
	Use:   "serve",
	Short: "Start OpenAI-compatible inference server",
	Long:  "Starts an HTTP server serving /v1/completions and /v1/chat/completions using the configured ML backend.",
	RunE:  runServe,
}

var (
	serveBind        string
	serveModelPath   string
	serveThreads     int
	serveMaxTokens   int
	serveTimeout     int
	serveMaxRequests int
	serveMaxContext  int
)

func init() {
	serveCmd.Flags().StringVar(&serveBind, "bind", "0.0.0.0:8090", "Address to bind")
	serveCmd.Flags().StringVar(&serveModelPath, "model-path", "", "Path to model directory (for mlx backend)")
	serveCmd.Flags().IntVar(&serveThreads, "threads", 0, "Max CPU threads (0 = all available)")
	serveCmd.Flags().IntVar(&serveMaxTokens, "max-tokens", 4096, "Default max tokens per request")
	serveCmd.Flags().IntVar(&serveTimeout, "timeout", 300, "Request timeout in seconds")
	serveCmd.Flags().IntVar(&serveMaxRequests, "max-requests", 1, "Max concurrent requests (Metal is single-stream)")
	serveCmd.Flags().IntVar(&serveMaxContext, "max-context", 4, "Max chat messages to keep (sliding window, 0=unlimited)")
}
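
// Example invocation (binary name and flag values illustrative; only the
// "serve" subcommand and its flags are defined in this file):
//
//	<cli> ml serve --bind 127.0.0.1:8090 --model-path ./models --threads 8 --max-tokens 2048
//
// With the defaults above the server listens on 0.0.0.0:8090, admits one
// request at a time, and keeps a sliding window of four chat messages per request.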

type completionRequest struct {
	Model       string  `json:"model"`
	Prompt      string  `json:"prompt"`
	MaxTokens   int     `json:"max_tokens"`
	Temperature float64 `json:"temperature"`
	Stream      bool    `json:"stream"`
}
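
// A request body for POST /v1/completions, matching the tags above
// (values illustrative):
//
//	{"model":"default","prompt":"Hello","max_tokens":64,"temperature":0.7,"stream":false}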

type completionResponse struct {
	ID      string             `json:"id"`
	Object  string             `json:"object"`
	Created int64              `json:"created"`
	Model   string             `json:"model"`
	Choices []completionChoice `json:"choices"`
	Usage   usageInfo          `json:"usage"`
}

type completionChoice struct {
	Text         string `json:"text"`
	Index        int    `json:"index"`
	FinishReason string `json:"finish_reason"`
}

type chatRequest struct {
	Model       string       `json:"model"`
	Messages    []ml.Message `json:"messages"`
	MaxTokens   int          `json:"max_tokens"`
	Temperature float64      `json:"temperature"`
	Stream      bool         `json:"stream"`
}

type chatResponse struct {
	ID      string       `json:"id"`
	Object  string       `json:"object"`
	Created int64        `json:"created"`
	Model   string       `json:"model"`
	Choices []chatChoice `json:"choices"`
}

type chatChoice struct {
	Message      ml.Message `json:"message"`
	Index        int        `json:"index"`
	FinishReason string     `json:"finish_reason"`
}

// SSE streaming types (OpenAI chunk format)
type chatChunkResponse struct {
	ID      string            `json:"id"`
	Object  string            `json:"object"`
	Created int64             `json:"created"`
	Model   string            `json:"model"`
	Choices []chatChunkChoice `json:"choices"`
}

type chatChunkChoice struct {
	Delta        chatChunkDelta `json:"delta"`
	Index        int            `json:"index"`
	FinishReason *string        `json:"finish_reason"`
}

type chatChunkDelta struct {
	Role    string `json:"role,omitempty"`
	Content string `json:"content,omitempty"`
}
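
// One SSE data line carrying a single generated chat token looks like
// (id, timestamp and model name illustrative):
//
//	data: {"id":"chatcmpl-1","object":"chat.completion.chunk","created":1700000000,"model":"mlx","choices":[{"delta":{"content":"Hi"},"index":0,"finish_reason":null}]}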

type completionChunkResponse struct {
	ID      string                  `json:"id"`
	Object  string                  `json:"object"`
	Created int64                   `json:"created"`
	Model   string                  `json:"model"`
	Choices []completionChunkChoice `json:"choices"`
}

type completionChunkChoice struct {
	Text         string  `json:"text"`
	Index        int     `json:"index"`
	FinishReason *string `json:"finish_reason"`
}

type usageInfo struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

func runServe(cmd *cli.Command, args []string) error {
	// Cap CPU threads
	if serveThreads > 0 {
		prev := runtime.GOMAXPROCS(serveThreads)
		slog.Info("ml serve: capped threads", "threads", serveThreads, "previous", prev)
	}

	backend, err := createServeBackend()
	if err != nil {
		return err
	}

	// Check if backend supports streaming
	streamer, canStream := backend.(ml.StreamingBackend)

	// Request tracking
	var activeRequests atomic.Int32
	startTime := time.Now()

	mux := http.NewServeMux()

	// Health endpoint
	mux.HandleFunc("GET /healthz", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]any{
			"status":          "ok",
			"model":           backend.Name(),
			"uptime_seconds":  int(time.Since(startTime).Seconds()),
			"active_requests": activeRequests.Load(),
			"max_threads":     runtime.GOMAXPROCS(0),
			"max_tokens":      serveMaxTokens,
			"max_context":     serveMaxContext,
		})
	})
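
	// Example /healthz response (values illustrative; encoding/json sorts the map keys):
	//
	//	{"active_requests":0,"max_context":4,"max_threads":8,"max_tokens":4096,"model":"mlx","status":"ok","uptime_seconds":42}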
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
// Concurrency gate
if int(activeRequests.Load()) >= serveMaxRequests {
http.Error(w, `{"error":"server busy, max concurrent requests reached"}`, http.StatusTooManyRequests)
return
}
activeRequests.Add(1)
defer activeRequests.Add(-1)

		// Request timeout
		ctx, cancel := context.WithTimeout(r.Context(), time.Duration(serveTimeout)*time.Second)
		defer cancel()
		r = r.WithContext(ctx)

		body, _ := io.ReadAll(r.Body)
		var req completionRequest
		if err := json.Unmarshal(body, &req); err != nil {
			http.Error(w, err.Error(), 400)
			return
		}

		// Enforce server-level max-tokens cap
		if req.MaxTokens == 0 || req.MaxTokens > serveMaxTokens {
			req.MaxTokens = serveMaxTokens
		}

		opts := ml.GenOpts{
			Temperature: req.Temperature,
			MaxTokens:   req.MaxTokens,
			Model:       req.Model,
		}

		// Streaming path
		if req.Stream && canStream {
			id := fmt.Sprintf("cmpl-%d", time.Now().UnixNano())
			created := time.Now().Unix()
			w.Header().Set("Content-Type", "text/event-stream")
			w.Header().Set("Cache-Control", "no-cache")
			w.Header().Set("Connection", "keep-alive")
			w.Header().Set("X-Accel-Buffering", "no")
			flusher, ok := w.(http.Flusher)
			if !ok {
				http.Error(w, "streaming not supported", 500)
				return
			}

			err := streamer.GenerateStream(r.Context(), req.Prompt, opts, func(token string) error {
				chunk := completionChunkResponse{
					ID:      id,
					Object:  "text_completion",
					Created: created,
					Model:   backend.Name(),
					Choices: []completionChunkChoice{{Text: token}},
				}
				data, _ := json.Marshal(chunk)
				fmt.Fprintf(w, "data: %s\n\n", data)
				flusher.Flush()
				return nil
			})
			if err != nil {
				slog.Error("stream error", "err", err)
			}

			// Send final chunk with finish_reason
			stop := "stop"
			final := completionChunkResponse{
				ID:      id,
				Object:  "text_completion",
				Created: created,
				Model:   backend.Name(),
				Choices: []completionChunkChoice{{FinishReason: &stop}},
			}
			data, _ := json.Marshal(final)
			fmt.Fprintf(w, "data: %s\n\n", data)
			fmt.Fprintf(w, "data: [DONE]\n\n")
			flusher.Flush()
			return
		}

		// Non-streaming path
		text, err := backend.Generate(r.Context(), req.Prompt, opts)
		if err != nil {
			http.Error(w, err.Error(), 500)
			return
		}
		resp := completionResponse{
			ID:      fmt.Sprintf("cmpl-%d", time.Now().UnixNano()),
			Object:  "text_completion",
			Created: time.Now().Unix(),
			Model:   backend.Name(),
			Choices: []completionChoice{{Text: text, FinishReason: "stop"}},
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(resp)
	})
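
	// A minimal non-streaming request against this endpoint (host, port and
	// payload illustrative, assuming the default --bind):
	//
	//	curl -s http://localhost:8090/v1/completions \
	//	  -H 'Content-Type: application/json' \
	//	  -d '{"prompt":"Hello","max_tokens":32}'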
mux.HandleFunc("POST /v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
// Concurrency gate
if int(activeRequests.Load()) >= serveMaxRequests {
http.Error(w, `{"error":"server busy, max concurrent requests reached"}`, http.StatusTooManyRequests)
return
}
activeRequests.Add(1)
defer activeRequests.Add(-1)

		// Request timeout
		ctx, cancel := context.WithTimeout(r.Context(), time.Duration(serveTimeout)*time.Second)
		defer cancel()
		r = r.WithContext(ctx)

		body, _ := io.ReadAll(r.Body)
		var req chatRequest
		if err := json.Unmarshal(body, &req); err != nil {
			http.Error(w, err.Error(), 400)
			return
		}

		// Enforce server-level max-tokens cap
		if req.MaxTokens == 0 || req.MaxTokens > serveMaxTokens {
			req.MaxTokens = serveMaxTokens
		}

		// Sliding window: keep system prompt (if any) + last N messages.
		// Prevents KV-cache explosion on multi-turn conversations.
		if serveMaxContext > 0 && len(req.Messages) > serveMaxContext {
			var kept []ml.Message
			rest := req.Messages
			// Preserve system message if present
			if len(rest) > 0 && rest[0].Role == "system" {
				kept = append(kept, rest[0])
				rest = rest[1:]
			}
			// Keep only the last N user/assistant messages
			if len(rest) > serveMaxContext {
				rest = rest[len(rest)-serveMaxContext:]
			}
			req.Messages = append(kept, rest...)
			slog.Debug("ml serve: context window applied", "kept", len(req.Messages))
		}
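
		// Worked example of the window above with --max-context 4: a history of
		// [system, u1, a1, u2, a2, u3] becomes [system, a1, u2, a2, u3], i.e.
		// the system prompt plus the last four messages.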

		opts := ml.GenOpts{
			Temperature: req.Temperature,
			MaxTokens:   req.MaxTokens,
			Model:       req.Model,
		}

		// Streaming path
		if req.Stream && canStream {
			id := fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano())
			created := time.Now().Unix()
			w.Header().Set("Content-Type", "text/event-stream")
			w.Header().Set("Cache-Control", "no-cache")
			w.Header().Set("Connection", "keep-alive")
			w.Header().Set("X-Accel-Buffering", "no")
			flusher, ok := w.(http.Flusher)
			if !ok {
				http.Error(w, "streaming not supported", 500)
				return
			}

			// Send initial role chunk
			roleChunk := chatChunkResponse{
				ID:      id,
				Object:  "chat.completion.chunk",
				Created: created,
				Model:   backend.Name(),
				Choices: []chatChunkChoice{{Delta: chatChunkDelta{Role: "assistant"}}},
			}
			data, _ := json.Marshal(roleChunk)
			fmt.Fprintf(w, "data: %s\n\n", data)
			flusher.Flush()

			err := streamer.ChatStream(r.Context(), req.Messages, opts, func(token string) error {
				chunk := chatChunkResponse{
					ID:      id,
					Object:  "chat.completion.chunk",
					Created: created,
					Model:   backend.Name(),
					Choices: []chatChunkChoice{{Delta: chatChunkDelta{Content: token}}},
				}
				data, _ := json.Marshal(chunk)
				fmt.Fprintf(w, "data: %s\n\n", data)
				flusher.Flush()
				return nil
			})
			if err != nil {
				slog.Error("stream error", "err", err)
			}

			// Send final chunk with finish_reason
			stop := "stop"
			final := chatChunkResponse{
				ID:      id,
				Object:  "chat.completion.chunk",
				Created: created,
				Model:   backend.Name(),
				Choices: []chatChunkChoice{{Delta: chatChunkDelta{}, FinishReason: &stop}},
			}
			data, _ = json.Marshal(final)
			fmt.Fprintf(w, "data: %s\n\n", data)
			fmt.Fprintf(w, "data: [DONE]\n\n")
			flusher.Flush()
			return
		}

		// Non-streaming path
		text, err := backend.Chat(r.Context(), req.Messages, opts)
		if err != nil {
			http.Error(w, err.Error(), 500)
			return
		}
		resp := chatResponse{
			ID:      fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()),
			Object:  "chat.completion",
			Created: time.Now().Unix(),
			Model:   backend.Name(),
			Choices: []chatChoice{{
				Message:      ml.Message{Role: "assistant", Content: text},
				FinishReason: "stop",
			}},
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(resp)
	})
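
	// A streaming request against this endpoint (host, port and payload
	// illustrative); the reply is a text/event-stream of "data: {...}" lines
	// terminated by "data: [DONE]":
	//
	//	curl -sN http://localhost:8090/v1/chat/completions \
	//	  -H 'Content-Type: application/json' \
	//	  -d '{"messages":[{"role":"user","content":"Hi"}],"stream":true}'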
mux.HandleFunc("GET /v1/models", func(w http.ResponseWriter, r *http.Request) {
resp := struct {
Object string `json:"object"`
Data []struct {
ID string `json:"id"`
} `json:"data"`
}{
Object: "list",
Data: []struct {
ID string `json:"id"`
}{{ID: backend.Name()}},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
})
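
	// Example /v1/models response (model name illustrative):
	//
	//	{"object":"list","data":[{"id":"mlx"}]}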

	// Serve the lem-chat UI at root; same origin, no CORS needed
	mux.HandleFunc("GET /chat.js", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/javascript")
		w.Write(lemChatJS)
	})
	mux.HandleFunc("GET /", func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/" {
			http.NotFound(w, r)
			return
		}
		w.Header().Set("Content-Type", "text/html; charset=utf-8")
		fmt.Fprintf(w, chatHTML, backend.Name(), serveMaxTokens)
	})
slog.Info("ml serve: starting",
"bind", serveBind,
"backend", backend.Name(),
"streaming", canStream,
"threads", runtime.GOMAXPROCS(0),
"max_tokens", serveMaxTokens,
"max_context_msgs", serveMaxContext,
"timeout_s", serveTimeout,
"max_requests", serveMaxRequests,
)
fmt.Printf("Serving on http://%s\n", serveBind)
// Graceful shutdown on SIGINT/SIGTERM
srv := &http.Server{
Addr: serveBind,
Handler: mux,
}
errCh := make(chan error, 1)
go func() {
errCh <- srv.ListenAndServe()
}()
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
select {
case sig := <-sigCh:
slog.Info("ml serve: shutting down", "signal", sig)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := srv.Shutdown(ctx); err != nil {
slog.Error("ml serve: shutdown error", "err", err)
return err
}
slog.Info("ml serve: stopped cleanly")
return nil
case err := <-errCh:
return err
}
}