diff --git a/api.go b/api.go
index 41f39cc..d875ae8 100644
--- a/api.go
+++ b/api.go
@@ -38,6 +38,8 @@ type Engine struct {
 	addr        string
 	groups      []RouteGroup
 	middlewares []gin.HandlerFunc
+	chatCompletionsResolver *ModelResolver
+	chatCompletionsPath     string
 	cacheTTL        time.Duration
 	cacheMaxEntries int
 	cacheMaxBytes   int
@@ -241,6 +243,12 @@ func (e *Engine) build() *gin.Engine {
 		c.JSON(http.StatusOK, OK("healthy"))
 	})
 
+	// Mount the local OpenAI-compatible chat completion endpoint when configured.
+	if e.chatCompletionsResolver != nil {
+		h := newChatCompletionsHandler(e.chatCompletionsResolver)
+		r.POST(e.chatCompletionsPath, h.ServeHTTP)
+	}
+
 	// Mount each registered group at its base path.
 	for _, g := range e.groups {
 		if isNilRouteGroup(g) {
diff --git a/chat_completions.go b/chat_completions.go
new file mode 100644
index 0000000..9d9ef66
--- /dev/null
+++ b/chat_completions.go
@@ -0,0 +1,851 @@
+// SPDX-License-Identifier: EUPL-1.2
+
+package api
+
+import (
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"math/rand"
+	"net/http"
+	"os"
+	"strconv"
+	"strings"
+	"sync"
+	"time"
+	"unicode"
+
+	"dappco.re/go/core"
+	inference "dappco.re/go/core/inference"
+
+	"github.com/gin-gonic/gin"
+	"gopkg.in/yaml.v3"
+)
+
+const defaultChatCompletionsPath = "/v1/chat/completions"
+
+// Sampling defaults applied when the request omits the corresponding field.
+const (
+	chatDefaultTemperature = 1.0
+	chatDefaultTopP        = 0.95
+	chatDefaultTopK        = 64
+	chatDefaultMaxTokens   = 2048
+)
+
+// channelMarker delimits channel switches (e.g. "thought" vs. response
+// text) inside the raw token stream.
+// NOTE(review): harmony-style streams usually spell this "<|channel|>"
+// (with a closing pipe) — confirm against what the models actually emit.
+const channelMarker = "<|channel>"
+
+// ChatCompletionRequest is the OpenAI-compatible request body.
+//
+//	body := ChatCompletionRequest{
+//		Model:    "lemer",
+//		Messages: []ChatMessage{{Role: "user", Content: "What is 2+2?"}},
+//		Stream:   true,
+//	}
+type ChatCompletionRequest struct {
+	Model       string        `json:"model"`
+	Messages    []ChatMessage `json:"messages"`
+	Temperature *float32      `json:"temperature,omitempty"`
+	TopP        *float32      `json:"top_p,omitempty"`
+	TopK        *int          `json:"top_k,omitempty"`
+	MaxTokens   *int          `json:"max_tokens,omitempty"`
+	Stream      bool          `json:"stream,omitempty"`
+	Stop        []string      `json:"stop,omitempty"`
+	User        string        `json:"user,omitempty"`
+}
+
+// ChatMessage is a single turn in a conversation.
+//
+//	msg := ChatMessage{Role: "user", Content: "Hello"}
+type ChatMessage struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
+}
+
+// ChatCompletionResponse is the OpenAI-compatible response body.
+//
+//	resp.Choices[0].Message.Content // "4"
+type ChatCompletionResponse struct {
+	ID      string       `json:"id"`
+	Object  string       `json:"object"`
+	Created int64        `json:"created"`
+	Model   string       `json:"model"`
+	Choices []ChatChoice `json:"choices"`
+	Usage   ChatUsage    `json:"usage"`
+	Thought *string      `json:"thought,omitempty"`
+}
+
+// ChatChoice is a single response option.
+//
+//	choice.Message.Content // The generated text
+//	choice.FinishReason    // "stop", "length", or "error"
+type ChatChoice struct {
+	Index        int         `json:"index"`
+	Message      ChatMessage `json:"message"`
+	FinishReason string      `json:"finish_reason"`
+}
+
+// ChatUsage reports token consumption for the request.
+//
+//	usage.TotalTokens // PromptTokens + CompletionTokens
+type ChatUsage struct {
+	PromptTokens     int `json:"prompt_tokens"`
+	CompletionTokens int `json:"completion_tokens"`
+	TotalTokens      int `json:"total_tokens"`
+}
+
+// ChatCompletionChunk is a single SSE chunk during streaming.
+//
+//	chunk.Choices[0].Delta.Content // Partial token text
+type ChatCompletionChunk struct {
+	ID      string            `json:"id"`
+	Object  string            `json:"object"`
+	Created int64             `json:"created"`
+	Model   string            `json:"model"`
+	Choices []ChatChunkChoice `json:"choices"`
+	Thought *string           `json:"thought,omitempty"`
+}
+
+// ChatChunkChoice is a streaming delta.
+//
+//	delta.Content // New token(s) in this chunk
+type ChatChunkChoice struct {
+	Index        int              `json:"index"`
+	Delta        ChatMessageDelta `json:"delta"`
+	FinishReason *string          `json:"finish_reason"`
+}
+
+// ChatMessageDelta is the incremental content within a streaming chunk.
+//
+//	delta.Content // "" on first chunk (role-only), then token text
+type ChatMessageDelta struct {
+	Role    string `json:"role,omitempty"`
+	Content string `json:"content,omitempty"`
+}
+
+// chatCompletionError is the OpenAI-style error payload.
+type chatCompletionError struct {
+	Message string `json:"message"`
+	Type    string `json:"type"`
+	Param   string `json:"param,omitempty"`
+	Code    string `json:"code"`
+}
+
+// chatCompletionErrorResponse wraps chatCompletionError under "error".
+type chatCompletionErrorResponse struct {
+	Error chatCompletionError `json:"error"`
+}
+
+// modelResolutionError carries a machine-readable code and the offending
+// request parameter alongside the human-readable message.
+type modelResolutionError struct {
+	code  string
+	param string
+	msg   string
+}
+
+func (e *modelResolutionError) Error() string {
+	if e == nil {
+		return ""
+	}
+	return e.msg
+}
+
+// ModelResolver resolves model names to loaded inference.TextModel instances.
+//
+// Resolution order:
+//
+//	1) Exact cache hit
+//	2) ~/.core/models.yaml path mapping
+//	3) discovery by architecture via inference.Discover()
+type ModelResolver struct {
+	mu           sync.RWMutex
+	loadedByName map[string]inference.TextModel
+	loadedByPath map[string]inference.TextModel
+	discovery    map[string]string
+}
+
+// NewModelResolver returns a resolver with empty caches; the zero value of
+// ModelResolver is not usable because its maps are nil.
+func NewModelResolver() *ModelResolver {
+	return &ModelResolver{
+		loadedByName: make(map[string]inference.TextModel),
+		loadedByPath: make(map[string]inference.TextModel),
+		discovery:    make(map[string]string),
+	}
+}
+
+// ResolveModel maps a model name to a loaded inference.TextModel.
+// Cached models are reused. Unknown names return an error.
+func (r *ModelResolver) ResolveModel(name string) (inference.TextModel, error) {
+	if r == nil {
+		return nil, &modelResolutionError{
+			code:  "model_not_found",
+			param: "model",
+			msg:   "model resolver is not configured",
+		}
+	}
+
+	// Names are matched case-insensitively after trimming whitespace.
+	requested := core.Lower(strings.TrimSpace(name))
+	if requested == "" {
+		return nil, &modelResolutionError{
+			code:  "invalid_request_error",
+			param: "model",
+			msg:   "model is required",
+		}
+	}
+
+	r.mu.RLock()
+	if cached, ok := r.loadedByName[requested]; ok {
+		r.mu.RUnlock()
+		return cached, nil
+	}
+	r.mu.RUnlock()
+
+	if path, ok := r.lookupModelPath(requested); ok {
+		return r.loadByPath(requested, path)
+	}
+
+	if path, ok := r.resolveDiscoveredPath(requested); ok {
+		return r.loadByPath(requested, path)
+	}
+
+	return nil, &modelResolutionError{
+		code:  "model_not_found",
+		param: "model",
+		msg:   fmt.Sprintf("model %q not found", requested),
+	}
+}
+
+// loadByPath loads (or reuses) the model stored at path and caches it under
+// both the requested name and the cleaned path.
+//
+// NOTE(review): the lock is released around inference.LoadModel so other
+// requests are not blocked, which means two concurrent callers may both load
+// the same path; the second write wins and one instance is discarded.
+func (r *ModelResolver) loadByPath(name, path string) (inference.TextModel, error) {
+	cleanPath := core.Path(path)
+	r.mu.Lock()
+	if cached, ok := r.loadedByPath[cleanPath]; ok {
+		r.loadedByName[name] = cached
+		r.mu.Unlock()
+		return cached, nil
+	}
+	r.mu.Unlock()
+
+	loaded, err := inference.LoadModel(cleanPath)
+	if err != nil {
+		// A transient "loading" condition is reported distinctly so callers
+		// can map it to 503 rather than 404.
+		if strings.Contains(strings.ToLower(err.Error()), "loading") {
+			return nil, &modelResolutionError{
+				code:  "model_loading",
+				param: "model",
+				msg:   err.Error(),
+			}
+		}
+		return nil, &modelResolutionError{
+			code:  "model_not_found",
+			param: "model",
+			msg:   err.Error(),
+		}
+	}
+
+	r.mu.Lock()
+	r.loadedByName[name] = loaded
+	r.loadedByPath[cleanPath] = loaded
+	r.mu.Unlock()
+	return loaded, nil
+}
+
+// lookupModelPath returns the configured path for name from models.yaml,
+// if any.
+func (r *ModelResolver) lookupModelPath(name string) (string, bool) {
+	mappings, ok := r.modelsYAMLMapping()
+	if !ok {
+		return "", false
+	}
+
+	if path, ok := mappings[name]; ok && strings.TrimSpace(path) != "" {
+		return path, true
+	}
+	return "", false
+}
+
+// modelsYAMLMapping reads ~/.core/models.yaml and returns a normalized
+// (lowercased, trimmed) name → path mapping. Both a nested "models:" section
+// and top-level string entries are accepted; because top-level entries are
+// applied second, they take precedence over same-named "models:" entries.
+// Any read or parse failure is treated as "no mapping".
+func (r *ModelResolver) modelsYAMLMapping() (map[string]string, bool) {
+	configPath := core.Path(core.Env("DIR_HOME"), ".core", "models.yaml")
+	data, err := os.ReadFile(configPath)
+	if err != nil {
+		return nil, false
+	}
+
+	var content any
+	if err := yaml.Unmarshal(data, &content); err != nil {
+		return nil, false
+	}
+
+	root, ok := content.(map[string]any)
+	if !ok || root == nil {
+		return nil, false
+	}
+
+	normalized := make(map[string]string)
+
+	if models, ok := root["models"].(map[string]any); ok && models != nil {
+		for key, raw := range models {
+			if value, ok := raw.(string); ok {
+				normalized[core.Lower(strings.TrimSpace(key))] = strings.TrimSpace(value)
+			}
+		}
+	}
+
+	for key, raw := range root {
+		value, ok := raw.(string)
+		if !ok {
+			continue
+		}
+		normalized[core.Lower(strings.TrimSpace(key))] = strings.TrimSpace(value)
+	}
+
+	if len(normalized) == 0 {
+		return nil, false
+	}
+	return normalized, true
+}
+
+// resolveDiscoveredPath tries to match name (and its prefix before a ':',
+// e.g. "lemer:7b" → "lemer") against the model types discovered under
+// ~/.core/models, caching any hit for both candidate spellings.
+func (r *ModelResolver) resolveDiscoveredPath(name string) (string, bool) {
+	candidates := []string{name}
+	if n := strings.IndexRune(name, ':'); n > 0 {
+		candidates = append(candidates, name[:n])
+	}
+
+	r.mu.RLock()
+	for _, candidate := range candidates {
+		if path, ok := r.discovery[candidate]; ok {
+			r.mu.RUnlock()
+			return path, true
+		}
+	}
+	r.mu.RUnlock()
+
+	base := core.Path(core.Env("DIR_HOME"), ".core", "models")
+	var discovered string
+	for _, m := range discoveryModels(base) {
+		modelType := strings.ToLower(strings.TrimSpace(m.ModelType))
+		for _, candidate := range candidates {
+			if candidate != "" && candidate == modelType {
+				discovered = m.Path
+				break
+			}
+		}
+		if discovered != "" {
+			break
+		}
+	}
+
+	if discovered == "" {
+		return "", false
+	}
+
+	r.mu.Lock()
+	for _, candidate := range candidates {
+		if candidate != "" {
+			r.discovery[candidate] = discovered
+		}
+	}
+	r.mu.Unlock()
+
+	return discovered, true
+}
+
+// discoveredModel is the subset of discovery metadata the resolver needs.
+// ModelType is required: resolveDiscoveredPath matches requested names
+// against it.
+type discoveredModel struct {
+	Path      string
+	ModelType string
+}
+
+// discoveryModels lists discovered models under base, skipping entries with
+// a missing path or model type.
+func discoveryModels(base string) []discoveredModel {
+	var out []discoveredModel
+	for m := range inference.Discover(base) {
+		if m.Path == "" || m.ModelType == "" {
+			continue
+		}
+		// Carry ModelType through: resolveDiscoveredPath reads it.
+		out = append(out, discoveredModel{Path: m.Path, ModelType: m.ModelType})
+	}
+	return out
+}
+
+// ThinkingExtractor separates thinking channel content from response text.
+// Applied as a post-processing step on the token stream.
+//
+//	extractor := NewThinkingExtractor()
+//	for tok := range model.Chat(ctx, messages) {
+//		extractor.Process(tok)
+//	}
+//	response := extractor.Content()  // User-facing text
+//	thinking := extractor.Thinking() // Internal reasoning (may be nil)
+type ThinkingExtractor struct {
+	currentChannel string
+	content        strings.Builder
+	thought        strings.Builder
+}
+
+// NewThinkingExtractor returns an extractor that starts in the "assistant"
+// (user-facing) channel.
+func NewThinkingExtractor() *ThinkingExtractor {
+	return &ThinkingExtractor{
+		currentChannel: "assistant",
+	}
+}
+
+// Process routes the token's text into the current channel(s).
+func (te *ThinkingExtractor) Process(token inference.Token) {
+	te.writeDeltas(token.Text)
+}
+
+// Content returns the accumulated user-facing text.
+func (te *ThinkingExtractor) Content() string {
+	if te == nil {
+		return ""
+	}
+	return te.content.String()
+}
+
+// Thinking returns the accumulated thought-channel text, or nil when the
+// stream produced none.
+func (te *ThinkingExtractor) Thinking() *string {
+	if te == nil {
+		return nil
+	}
+	if te.thought.Len() == 0 {
+		return nil
+	}
+	out := te.thought.String()
+	return &out
+}
+
+// writeDeltas splits text on channelMarker occurrences, switching channels as
+// directed, and returns the content and thought text appended by this call.
+// A marker not followed by a parseable channel name is emitted literally.
+func (te *ThinkingExtractor) writeDeltas(text string) (string, string) {
+	// The nil guard must run before any field access; dereferencing a nil
+	// receiver to read builder lengths would panic.
+	if te == nil {
+		return "", ""
+	}
+
+	beforeContentLen := te.content.Len()
+	beforeThoughtLen := te.thought.Len()
+
+	remaining := text
+	for {
+		next := strings.Index(remaining, channelMarker)
+		if next < 0 {
+			te.writeToCurrentChannel(remaining)
+			break
+		}
+
+		te.writeToCurrentChannel(remaining[:next])
+		remaining = remaining[next+len(channelMarker):]
+
+		remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
+		if remaining == "" {
+			break
+		}
+
+		chanName, consumed := parseChannelName(remaining)
+		if consumed <= 0 {
+			// No channel name follows: restore the marker as literal text.
+			te.writeToCurrentChannel(channelMarker)
+			if remaining != "" {
+				te.writeToCurrentChannel(remaining)
+			}
+			break
+		}
+
+		if chanName == "" {
+			te.writeToCurrentChannel(channelMarker)
+		} else {
+			te.currentChannel = chanName
+		}
+		remaining = remaining[consumed:]
+	}
+
+	return te.content.String()[beforeContentLen:], te.thought.String()[beforeThoughtLen:]
+}
+
+// writeToCurrentChannel appends text to the thought buffer when the current
+// channel is "thought", otherwise to the content buffer.
+func (te *ThinkingExtractor) writeToCurrentChannel(text string) {
+	if text == "" {
+		return
+	}
+
+	if te.currentChannel == "thought" {
+		te.thought.WriteString(text)
+		return
+	}
+	te.content.WriteString(text)
+}
+
+// parseChannelName reads a leading [A-Za-z0-9_-]+ run from s and returns it
+// lowercased together with the number of bytes consumed; (s must already have
+// leading whitespace stripped). Returns ("", 0) when s does not start with a
+// name character.
+func parseChannelName(s string) (string, int) {
+	if s == "" {
+		return "", 0
+	}
+	count := 0
+	for i := 0; i < len(s); i++ {
+		c := s[i]
+		if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_' || c == '-' {
+			count++
+			continue
+		}
+		break
+	}
+	if count == 0 {
+		return "", 0
+	}
+	return strings.ToLower(s[:count]), count
+}
+
+// chatCompletionsHandler serves the OpenAI-compatible chat completions route.
+type chatCompletionsHandler struct {
+	resolver *ModelResolver
+}
+
+func newChatCompletionsHandler(resolver *ModelResolver) *chatCompletionsHandler {
+	return &chatCompletionsHandler{
+		resolver: resolver,
+	}
+}
+
+// ServeHTTP decodes and validates the request, resolves the model, and
+// dispatches to the streaming or non-streaming path.
+func (h *chatCompletionsHandler) ServeHTTP(c *gin.Context) {
+	if h == nil || h.resolver == nil {
+		writeChatCompletionError(c, http.StatusServiceUnavailable, "invalid_request_error", "model", "chat handler is not configured", "")
+		return
+	}
+
+	var req ChatCompletionRequest
+	if err := decodeJSONBody(c.Request.Body, &req); err != nil {
+		writeChatCompletionError(c, http.StatusBadRequest, "invalid_request_error", "body", "invalid request body", "")
+		return
+	}
+
+	if err := validateChatRequest(&req); err != nil {
+		var reqErr *chatCompletionRequestError
+		if !errors.As(err, &reqErr) {
+			writeChatCompletionError(c, http.StatusBadRequest, "invalid_request_error", "body", err.Error(), "")
+			return
+		}
+		// Field order matches writeChatCompletionError(c, status, type, param, message, code).
+		writeChatCompletionError(c, reqErr.Status, reqErr.Type, reqErr.Param, reqErr.Message, reqErr.Code)
+		return
+	}
+
+	model, err := h.resolver.ResolveModel(req.Model)
+	if err != nil {
+		// Use the mapped error type and code; previously the type was
+		// hard-coded and the code was assigned only after the response had
+		// already been written (dead code).
+		status, errType, errCode, errParam := mapResolverError(err)
+		writeChatCompletionError(c, status, errType, errParam, err.Error(), errCode)
+		return
+	}
+
+	options, err := chatRequestOptions(&req)
+	if err != nil {
+		writeChatCompletionError(c, http.StatusBadRequest, "invalid_request_error", "stop", err.Error(), "")
+		return
+	}
+
+	messages := make([]inference.Message, 0, len(req.Messages))
+	for _, msg := range req.Messages {
+		messages = append(messages, inference.Message{
+			Role:    msg.Role,
+			Content: msg.Content,
+		})
+	}
+
+	if req.Stream {
+		h.serveStreaming(c, model, req, messages, options...)
+		return
+	}
+	h.serveNonStreaming(c, model, req, messages, options...)
+}
+
+// serveNonStreaming runs the full generation, then writes a single
+// chat.completion JSON response (or an error if generation failed).
+func (h *chatCompletionsHandler) serveNonStreaming(c *gin.Context, model inference.TextModel, req ChatCompletionRequest, messages []inference.Message, opts ...inference.GenerateOption) {
+	ctx := c.Request.Context()
+	created := time.Now().Unix()
+	completionID := newChatCompletionID()
+
+	extractor := NewThinkingExtractor()
+	for tok := range model.Chat(ctx, messages, opts...) {
+		extractor.Process(tok)
+	}
+	if err := model.Err(); err != nil {
+		writeInferenceError(c, err)
+		return
+	}
+
+	metrics := model.Metrics()
+	content := extractor.Content()
+	finishReason := "stop"
+	if isTokenLengthCapReached(req.MaxTokens, metrics.GeneratedTokens) {
+		finishReason = "length"
+	}
+
+	response := ChatCompletionResponse{
+		ID:      completionID,
+		Object:  "chat.completion",
+		Created: created,
+		Model:   req.Model,
+		Choices: []ChatChoice{
+			{
+				Index: 0,
+				Message: ChatMessage{
+					Role:    "assistant",
+					Content: content,
+				},
+				FinishReason: finishReason,
+			},
+		},
+		Usage: ChatUsage{
+			PromptTokens:     metrics.PromptTokens,
+			CompletionTokens: metrics.GeneratedTokens,
+			TotalTokens:      metrics.PromptTokens + metrics.GeneratedTokens,
+		},
+	}
+	if thought := extractor.Thinking(); thought != nil {
+		response.Thought = thought
+	}
+
+	c.JSON(http.StatusOK, response)
+}
+
+// serveStreaming emits SSE chunks as tokens arrive, then a final chunk with a
+// finish_reason followed by the [DONE] sentinel. If generation fails before
+// anything was sent, a JSON error is written instead.
+func (h *chatCompletionsHandler) serveStreaming(c *gin.Context, model inference.TextModel, req ChatCompletionRequest, messages []inference.Message, opts ...inference.GenerateOption) {
+	ctx := c.Request.Context()
+	created := time.Now().Unix()
+	completionID := newChatCompletionID()
+
+	c.Header("Content-Type", "text/event-stream")
+	c.Header("Cache-Control", "no-cache")
+	c.Header("Connection", "keep-alive")
+	c.Status(http.StatusOK)
+	c.Writer.Flush()
+
+	extractor := NewThinkingExtractor()
+	chunkFirst := true
+	sentAny := false
+
+	for tok := range model.Chat(ctx, messages, opts...) {
+		contentDelta, thoughtDelta := extractor.writeDeltas(tok.Text)
+		// The first chunk is sent even when empty so the client receives the
+		// role; later empty deltas are skipped.
+		if !chunkFirst && contentDelta == "" && thoughtDelta == "" {
+			continue
+		}
+
+		delta := ChatMessageDelta{}
+		if chunkFirst {
+			delta.Role = "assistant"
+		}
+		delta.Content = contentDelta
+
+		chunk := ChatCompletionChunk{
+			ID:      completionID,
+			Object:  "chat.completion.chunk",
+			Created: created,
+			Model:   req.Model,
+			Choices: []ChatChunkChoice{
+				{
+					Index:        0,
+					Delta:        delta,
+					FinishReason: nil,
+				},
+			},
+		}
+		if thoughtDelta != "" {
+			t := thoughtDelta
+			chunk.Thought = &t
+		}
+
+		if encoded, encodeErr := json.Marshal(chunk); encodeErr == nil {
+			c.Writer.WriteString(fmt.Sprintf("data: %s\n\n", encoded))
+			c.Writer.Flush()
+			sentAny = true
+		}
+		chunkFirst = false
+	}
+
+	genErr := model.Err()
+	if genErr != nil && !sentAny {
+		// Nothing streamed yet: a plain JSON error response is still possible.
+		writeInferenceError(c, genErr)
+		return
+	}
+
+	metrics := model.Metrics()
+	finishReason := "stop"
+	switch {
+	case genErr != nil:
+		finishReason = "error"
+	case isTokenLengthCapReached(req.MaxTokens, metrics.GeneratedTokens):
+		finishReason = "length"
+	}
+
+	finished := finishReason
+	finalChunk := ChatCompletionChunk{
+		ID:      completionID,
+		Object:  "chat.completion.chunk",
+		Created: created,
+		Model:   req.Model,
+		Choices: []ChatChunkChoice{
+			{
+				Index:        0,
+				Delta:        ChatMessageDelta{},
+				FinishReason: &finished,
+			},
+		},
+	}
+	if encoded, encodeErr := json.Marshal(finalChunk); encodeErr == nil {
+		c.Writer.WriteString(fmt.Sprintf("data: %s\n\n", encoded))
+	}
+	c.Writer.WriteString("data: [DONE]\n\n")
+	c.Writer.Flush()
+}
+
+// writeInferenceError maps a generation error to the shared error response:
+// transient "loading" conditions become 503/model_loading, everything else
+// 500/inference_error. Shared by both serve paths.
+func writeInferenceError(c *gin.Context, err error) {
+	if strings.Contains(strings.ToLower(err.Error()), "loading") {
+		writeChatCompletionError(c, http.StatusServiceUnavailable, "model_loading", "model", err.Error(), "")
+		return
+	}
+	writeChatCompletionError(c, http.StatusInternalServerError, "inference_error", "model", err.Error(), "")
+}
+
+// chatCompletionRequestError is a validation failure with full OpenAI-style
+// error metadata.
+type chatCompletionRequestError struct {
+	Status  int
+	Type    string
+	Code    string
+	Param   string
+	Message string
+}
+
+func (e *chatCompletionRequestError) Error() string {
+	if e == nil {
+		return ""
+	}
+	return e.Message
+}
+
+// validateChatRequest checks the structural requirements of the request:
+// non-empty model, non-empty messages, and a role on every message.
+func validateChatRequest(req *ChatCompletionRequest) error {
+	if strings.TrimSpace(req.Model) == "" {
+		return &chatCompletionRequestError{
+			Status:  http.StatusBadRequest,
+			Type:    "invalid_request_error",
+			Code:    "invalid_request_error",
+			Param:   "model",
+			Message: "model is required",
+		}
+	}
+
+	if len(req.Messages) == 0 {
+		return &chatCompletionRequestError{
+			Status:  http.StatusBadRequest,
+			Type:    "invalid_request_error",
+			Code:    "invalid_request_error",
+			Param:   "messages",
+			Message: "messages must be a non-empty array",
+		}
+	}
+
+	for i, msg := range req.Messages {
+		if strings.TrimSpace(msg.Role) == "" {
+			return &chatCompletionRequestError{
+				Status:  http.StatusBadRequest,
+				Type:    "invalid_request_error",
+				Code:    "invalid_request_error",
+				Param:   fmt.Sprintf("messages[%d].role", i),
+				Message: "message role is required",
+			}
+		}
+	}
+
+	return nil
+}
+
+// chatRequestOptions translates request fields into generation options,
+// falling back to the chatDefault* constants for omitted fields.
+func chatRequestOptions(req *ChatCompletionRequest) ([]inference.GenerateOption, error) {
+	opts := make([]inference.GenerateOption, 0, 5)
+	opts = append(opts, inference.WithTemperature(chatResolvedFloat(req.Temperature, chatDefaultTemperature)))
+	opts = append(opts, inference.WithTopP(chatResolvedFloat(req.TopP, chatDefaultTopP)))
+	opts = append(opts, inference.WithTopK(chatResolvedInt(req.TopK, chatDefaultTopK)))
+	opts = append(opts, inference.WithMaxTokens(chatResolvedInt(req.MaxTokens, chatDefaultMaxTokens)))
+
+	stops, err := parsedStopTokens(req.Stop)
+	if err != nil {
+		return nil, err
+	}
+	if len(stops) > 0 {
+		opts = append(opts, inference.WithStopTokens(stops...))
+	}
+	return opts, nil
+}
+
+// chatResolvedFloat returns *v, or def when v is nil.
+func chatResolvedFloat(v *float32, def float32) float32 {
+	if v == nil {
+		return def
+	}
+	return *v
+}
+
+// chatResolvedInt returns *v, or def when v is nil.
+func chatResolvedInt(v *int, def int) int {
+	if v == nil {
+		return def
+	}
+	return *v
+}
+
+// parsedStopTokens converts stop entries into int32 token IDs.
+// NOTE(review): this diverges from OpenAI semantics, where "stop" holds
+// literal stop *strings*; here each entry must be a decimal token ID because
+// inference.WithStopTokens takes IDs.
+func parsedStopTokens(stops []string) ([]int32, error) {
+	if len(stops) == 0 {
+		return nil, nil
+	}
+
+	out := make([]int32, 0, len(stops))
+	for _, raw := range stops {
+		raw = strings.TrimSpace(raw)
+		if raw == "" {
+			return nil, fmt.Errorf("stop entries cannot be empty")
+		}
+		parsed, err := strconv.ParseInt(raw, 10, 32)
+		if err != nil {
+			return nil, fmt.Errorf("invalid stop token %q", raw)
+		}
+		out = append(out, int32(parsed))
+	}
+	return out, nil
+}
+
+// isTokenLengthCapReached reports whether generation stopped because the
+// requested max_tokens cap was hit. A nil or non-positive cap never matches.
+func isTokenLengthCapReached(maxTokens *int, generated int) bool {
+	if maxTokens == nil || *maxTokens <= 0 {
+		return false
+	}
+	return generated >= *maxTokens
+}
+
+// mapResolverError maps a resolution error to (status, type, code, param).
+func mapResolverError(err error) (int, string, string, string) {
+	var resErr *modelResolutionError
+	if !errors.As(err, &resErr) {
+		return http.StatusInternalServerError, "inference_error", "inference_error", "model"
+	}
+	switch resErr.code {
+	case "model_loading":
+		return http.StatusServiceUnavailable, "model_loading", "model_loading", resErr.param
+	case "model_not_found":
+		return http.StatusNotFound, "model_not_found", "model_not_found", resErr.param
+	default:
+		return http.StatusInternalServerError, "inference_error", "inference_error", resErr.param
+	}
+}
+
+// writeChatCompletionError writes an OpenAI-style error body with the given
+// HTTP status. An empty code falls back to the error type.
+func writeChatCompletionError(c *gin.Context, status int, errType, param, message, code string) {
+	if status <= 0 {
+		status = http.StatusInternalServerError
+	}
+	// Headers must be set before the body is written: c.JSON flushes the
+	// header block, so setting Retry-After afterwards would be a no-op.
+	if status == http.StatusServiceUnavailable {
+		c.Header("Retry-After", "10")
+	}
+	resp := chatCompletionErrorResponse{
+		Error: chatCompletionError{
+			Message: message,
+			Type:    errType,
+			Param:   param,
+			Code:    codeOrDefault(code, errType),
+		},
+	}
+	c.JSON(status, resp)
+}
+
+// codeOrDefault returns code, then fallback, then "inference_error" — the
+// first non-empty wins.
+func codeOrDefault(code, fallback string) string {
+	if code != "" {
+		return code
+	}
+	if fallback != "" {
+		return fallback
+	}
+	return "inference_error"
+}
+
+// newChatCompletionID builds a loosely-unique "chatcmpl-<unix>-<rand>" ID.
+func newChatCompletionID() string {
+	return fmt.Sprintf("chatcmpl-%d-%06d", time.Now().Unix(), rand.Intn(1_000_000))
+}
+
+// decodeJSONBody strictly decodes a JSON body into dest, rejecting unknown
+// fields.
+func decodeJSONBody(reader io.Reader, dest any) error {
+	decoder := json.NewDecoder(reader)
+	decoder.DisallowUnknownFields()
+	return decoder.Decode(dest)
+}