diff --git a/pkg/chat/chat.go b/pkg/chat/chat.go index 16df0a8a..efd70c90 100644 --- a/pkg/chat/chat.go +++ b/pkg/chat/chat.go @@ -2,58 +2,14 @@ package chat import ( "bufio" - "context" "io" "slices" "strings" "time" core "dappco.re/go/core" - guimcp "forge.lthn.ai/core/gui/pkg/mcp" ) -type ToolExecutor interface { - Manifest() []guimcp.ToolDescriptor - ManifestText() string - CallTool(ctx context.Context, name string, arguments map[string]any) (string, error) -} - -type ToolCallHandler struct { - executor ToolExecutor -} - -func NewToolCallHandler(executor ToolExecutor) *ToolCallHandler { - return &ToolCallHandler{executor: executor} -} - -func (h *ToolCallHandler) Execute(ctx context.Context, call ToolCall) ToolResult { - if h == nil || h.executor == nil { - return ToolResult{ - ToolCallID: call.ID, - Content: "tool execution unavailable", - } - } - content, err := h.executor.CallTool(ctx, call.Name, call.Arguments) - if err != nil { - return ToolResult{ - ToolCallID: call.ID, - Content: err.Error(), - } - } - return ToolResult{ - ToolCallID: call.ID, - Content: content, - } -} - -func (h *ToolCallHandler) ExecuteAll(ctx context.Context, calls []ToolCall) []ToolResult { - results := make([]ToolResult, 0, len(calls)) - for _, call := range calls { - results = append(results, h.Execute(ctx, call)) - } - return results -} - type StreamCallbacks struct { OnStart func(streamID string) OnToken func(content string) diff --git a/pkg/chat/service.go b/pkg/chat/service.go index 3e801d77..baa3b3a7 100644 --- a/pkg/chat/service.go +++ b/pkg/chat/service.go @@ -64,7 +64,7 @@ type Service struct { store *store.Store httpClient *http.Client toolExecutor ToolExecutor - toolHandler *ToolCallHandler + toolHandler ToolCallHandler pendingAttachments map[string][]ImageAttachment thinkingStates map[string]ThinkingState mu sync.Mutex @@ -227,6 +227,8 @@ func (s *Service) OnStartup(_ context.Context) core.Result { subsystem.RegisterTools(server) s.toolExecutor = subsystem } + registerMCPToolActions(s.Core(), s.toolExecutor) + s.toolExecutor = newActionToolExecutor(s.Core(), s.toolExecutor) s.toolHandler = NewToolCallHandler(s.toolExecutor) s.Core().RegisterQuery(s.handleQuery) s.registerActions() @@ -1078,6 +1080,7 @@ func (s *Service) send(ctx context.Context, input sendInput) (string, error) { if err != nil { return "", err } + assistantMessage = s.withInlineToolCall(conv.ID, assistantMessage) lastAssistantMessageID = assistantMessage.ID if hasRenderableContent(assistantMessage) { conv.Messages = append(conv.Messages, assistantMessage) @@ -1092,9 +1095,19 @@ func (s *Service) send(ctx context.Context, input sendInput) (string, error) { if len(assistantMessage.ToolCalls) == 0 { break } + if s.toolHandler == nil { + break + } - results := s.toolHandler.ExecuteAll(ctx, assistantMessage.ToolCalls) - for _, result := range results { + for _, call := range assistantMessage.ToolCalls { + resultContent, err := s.toolHandler.OnToolCall(ctx, call) + result := ToolResult{ + ToolCallID: call.ID, + Content: renderToolResultContent(resultContent), + } + if err != nil { + result.Content = err.Error() + } toolMessage := ChatMessage{ ID: "tool-" + strconv.FormatInt(s.now().UnixNano(), 36), Role: "tool", @@ -1185,6 +1198,29 @@ func (s *Service) streamAssistant(ctx context.Context, conv Conversation, settin return renderer.Message(messageID, conv.Model, s.now()), nil } +func (s *Service) withInlineToolCall(conversationID string, message ChatMessage) ChatMessage { + if len(message.ToolCalls) > 0 { + return message + } + + call, ok, err := parseInlineToolCall(message.Content) + if err != nil { + _ = s.Core().LogWarn(err, "chat.tool_call", "malformed inline tool_call ignored") + return message + } + if !ok { + return message + } + if call.ID == "" { + call.ID = "call-" + strconv.FormatInt(s.now().UnixNano(), 36) + } + message.Content = "" + message.ToolCalls = []ToolCall{call} + message.FinishReason = "tool_calls" + s.emit(ActionToolCallStarted{ConversationID: conversationID, MessageID: message.ID, Call: call}) + return message +} + func (s *Service) buildCompletionRequest(conv Conversation, settings ChatSettings) openAIRequest { request := openAIRequest{ Model: s.resolveModel(conv.Model, settings.DefaultModel), @@ -1197,13 +1233,18 @@ func (s *Service) buildCompletionRequest(conv Conversation, settings ChatSetting } systemPrompt := strings.TrimSpace(settings.SystemPrompt) - if s.toolExecutor != nil { - manifest := s.toolExecutor.ManifestText() + if s.toolHandler != nil { + manifest := s.toolHandler.BuildToolManifest() if manifest != "" { if systemPrompt != "" { - systemPrompt += "\n\n" + systemPrompt = manifest + "\n\n" + systemPrompt + } else { + systemPrompt = manifest } - systemPrompt += manifest + "\nUse tools when helpful. When a tool is needed, emit a tool call with valid JSON arguments." + } + } + if s.toolExecutor != nil { + if len(s.toolExecutor.Manifest()) > 0 { for _, tool := range s.toolExecutor.Manifest() { request.Tools = append(request.Tools, openAIToolSpec{ Type: "function", diff --git a/pkg/chat/tool_handler.go b/pkg/chat/tool_handler.go new file mode 100644 index 00000000..912c6d95 --- /dev/null +++ b/pkg/chat/tool_handler.go @@ -0,0 +1,236 @@ +package chat + +import ( + "context" + "sort" + "strings" + + core "dappco.re/go/core" + guimcp "forge.lthn.ai/core/gui/pkg/mcp" +) + +const mcpToolActionPrefix = "mcp.tool." + +// ToolExecutor is the chat-facing subset of the GUI MCP subsystem. +type ToolExecutor interface { + Manifest() []guimcp.ToolDescriptor + ManifestText() string + CallTool(ctx context.Context, name string, arguments map[string]any) (string, error) +} + +// ToolCallHandler intercepts model-emitted tool calls and renders the tool +// manifest that is injected into the system prompt. +type ToolCallHandler interface { + OnToolCall(ctx context.Context, call ToolCall) (result any, err error) + BuildToolManifest() string +} + +type mcpToolCallHandler struct { + executor ToolExecutor +} + +func NewToolCallHandler(executor ToolExecutor) ToolCallHandler { + if executor == nil { + return nil + } + return &mcpToolCallHandler{executor: executor} +} + +func (h *mcpToolCallHandler) OnToolCall(ctx context.Context, call ToolCall) (any, error) { + if h == nil || h.executor == nil { + return nil, core.E("chat.tool_call", "tool execution unavailable", nil) + } + call.Name = strings.TrimSpace(call.Name) + if call.Name == "" { + return nil, core.E("chat.tool_call", "tool name is required", nil) + } + if call.Arguments == nil { + call.Arguments = map[string]any{} + } + return h.executor.CallTool(ctx, call.Name, call.Arguments) +} + +func (h *mcpToolCallHandler) BuildToolManifest() string { + if h == nil || h.executor == nil { + return "" + } + + tools := h.executor.Manifest() + if len(tools) == 0 { + return strings.TrimSpace(h.executor.ManifestText()) + } + tools = append([]guimcp.ToolDescriptor(nil), tools...) + sort.Slice(tools, func(i, j int) bool { + return tools[i].Name < tools[j].Name + }) + + var builder strings.Builder + builder.WriteString("Available MCP tools:\n") + for _, tool := range tools { + builder.WriteString("- ") + builder.WriteString(tool.Name) + if strings.TrimSpace(tool.Description) != "" { + builder.WriteString(": ") + builder.WriteString(strings.TrimSpace(tool.Description)) + } + schema := tool.InputSchema + if schema == nil { + schema = map[string]any{"type": "object"} + } + builder.WriteString("\n input_schema: ") + builder.WriteString(jsonString(schema)) + builder.WriteString("\n") + } + builder.WriteString("\nWhen a tool is needed, emit exactly one JSON object in this shape: ") + builder.WriteString(`{"tool_call":{"name":"tool_name","arguments":{}}}`) + builder.WriteString(".") + return strings.TrimSpace(builder.String()) +} + +type actionToolExecutor struct { + core *core.Core + fallback ToolExecutor +} + +func newActionToolExecutor(c *core.Core, fallback ToolExecutor) ToolExecutor { + if c == nil || fallback == nil { + return fallback + } + return &actionToolExecutor{core: c, fallback: fallback} +} + +func registerMCPToolActions(c *core.Core, executor ToolExecutor) { + if c == nil || executor == nil { + return + } + for _, tool := range executor.Manifest() { + name := strings.TrimSpace(tool.Name) + if name == "" { + continue + } + actionName := mcpToolActionPrefix + name + if c.Action(actionName).Exists() { + continue + } + c.Action(actionName, func(ctx context.Context, opts core.Options) core.Result { + content, err := executor.CallTool(ctx, name, toolArgumentsFromOptions(opts)) + return core.Result{}.New(content, err) + }) + } +} + +func (e *actionToolExecutor) Manifest() []guimcp.ToolDescriptor { + if e == nil || e.fallback == nil { + return nil + } + return e.fallback.Manifest() +} + +func (e *actionToolExecutor) ManifestText() string { + if e == nil || e.fallback == nil { + return "" + } + return e.fallback.ManifestText() +} + +func (e *actionToolExecutor) CallTool(ctx context.Context, name string, arguments map[string]any) (string, error) { + if e == nil || e.fallback == nil { + return "", core.E("chat.tool_call", "tool execution unavailable", nil) + } + if e.core != nil { + result := e.core.Action(mcpToolActionPrefix+strings.TrimSpace(name)).Run(ctx, core.NewOptions(core.Option{ + Key: "arguments", + Value: arguments, + })) + if !result.OK { + return "", resultError(result) + } + return renderToolResultContent(result.Value), nil + } + return e.fallback.CallTool(ctx, name, arguments) +} + +type inlineToolCallEnvelope struct { + ToolCall *ToolCall `json:"tool_call"` +} + +func parseInlineToolCall(content string) (ToolCall, bool, error) { + trimmed := strings.TrimSpace(content) + if trimmed == "" || !strings.Contains(trimmed, "tool_call") { + return ToolCall{}, false, nil + } + + var envelope inlineToolCallEnvelope + if result := core.JSONUnmarshal([]byte(trimmed), &envelope); !result.OK { + return ToolCall{}, false, resultError(result) + } + if envelope.ToolCall == nil { + return ToolCall{}, false, nil + } + call := *envelope.ToolCall + call.Name = strings.TrimSpace(call.Name) + if call.Arguments == nil { + call.Arguments = map[string]any{} + } + return call, true, nil +} + +func toolArgumentsFromOptions(opts core.Options) map[string]any { + if value := opts.Get("arguments"); value.OK { + if arguments, ok := value.Value.(map[string]any); ok { + return cloneArguments(arguments) + } + var arguments map[string]any + if result := core.JSONUnmarshal([]byte(jsonString(value.Value)), &arguments); result.OK { + return arguments + } + } + + arguments := make(map[string]any, opts.Len()) + for _, item := range opts.Items() { + arguments[item.Key] = item.Value + } + return arguments +} + +func cloneArguments(arguments map[string]any) map[string]any { + if arguments == nil { + return map[string]any{} + } + clone := make(map[string]any, len(arguments)) + for key, value := range arguments { + clone[key] = value + } + return clone +} + +func renderToolResultContent(result any) string { + switch typed := result.(type) { + case nil: + return "" + case string: + return typed + case []byte: + return string(typed) + default: + return jsonString(typed) + } +} + +func jsonString(value any) string { + result := core.JSONMarshal(value) + if !result.OK { + return "{}" + } + if data, ok := result.Value.([]byte); ok { + return string(data) + } + return "{}" +} + +func resultError(result core.Result) error { + if err, ok := result.Value.(error); ok { + return err + } + return core.E("chat.tool_call", "unexpected result type", nil) +} diff --git a/pkg/chat/tool_handler_example_test.go b/pkg/chat/tool_handler_example_test.go new file mode 100644 index 00000000..d9ca19b6 --- /dev/null +++ b/pkg/chat/tool_handler_example_test.go @@ -0,0 +1,47 @@ +package chat + +import ( + "context" + "fmt" + "strings" + + guimcp "forge.lthn.ai/core/gui/pkg/mcp" +) + +type exampleToolExecutor struct{} + +func (exampleToolExecutor) Manifest() []guimcp.ToolDescriptor { + return []guimcp.ToolDescriptor{{ + Name: "layout_suggest", + Description: "Suggest a layout", + InputSchema: map[string]any{"type": "object"}, + }} +} + +func (exampleToolExecutor) ManifestText() string { + return "Available MCP tools:\n- layout_suggest: Suggest a layout" +} + +func (exampleToolExecutor) CallTool(_ context.Context, name string, _ map[string]any) (string, error) { + if name == "layout_suggest" { + return `{"mode":"left-right"}`, nil + } + return "", nil +} + +func ExampleNewToolCallHandler() { + handler := NewToolCallHandler(exampleToolExecutor{}) + result, err := handler.OnToolCall(context.Background(), ToolCall{ + ID: "call-1", + Name: "layout_suggest", + Arguments: map[string]any{"window_count": 2}, + }) + + fmt.Println(err == nil) + fmt.Println(result) + fmt.Println(strings.Contains(handler.BuildToolManifest(), "layout_suggest")) + // Output: + // true + // {"mode":"left-right"} + // true +} diff --git a/pkg/chat/tool_handler_test.go b/pkg/chat/tool_handler_test.go new file mode 100644 index 00000000..df349c53 --- /dev/null +++ b/pkg/chat/tool_handler_test.go @@ -0,0 +1,184 @@ +package chat + +import ( + "context" + "io" + "net/http" + "strings" + "sync" + "testing" + + core "dappco.re/go/core" + guimcp "forge.lthn.ai/core/gui/pkg/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type strictToolExecutor struct { + mu sync.Mutex + calls []ToolCall +} + +func (m *strictToolExecutor) Manifest() []guimcp.ToolDescriptor { + return []guimcp.ToolDescriptor{{ + Name: "layout_suggest", + Description: "Suggest a layout", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "window_count": map[string]any{"type": "integer"}, + }, + }, + }} +} + +func (m *strictToolExecutor) ManifestText() string { + return "Available MCP tools:\n- layout_suggest: Suggest a layout" +} + +func (m *strictToolExecutor) CallTool(_ context.Context, name string, arguments map[string]any) (string, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.calls = append(m.calls, ToolCall{Name: name, Arguments: arguments}) + if name != "layout_suggest" { + return "", core.E("test.tool", "unknown tool: "+name, nil) + } + return `{"mode":"left-right"}`, nil +} + +func (m *strictToolExecutor) Calls() []ToolCall { + m.mu.Lock() + defer m.mu.Unlock() + return append([]ToolCall(nil), m.calls...) +} + +type completionRecorder struct { + mu sync.Mutex + requests []openAIRequest + responses [][]string +} + +func (r *completionRecorder) ServeHTTP(w http.ResponseWriter, request *http.Request) { + body, _ := io.ReadAll(request.Body) + var completion openAIRequest + if result := core.JSONUnmarshal(body, &completion); !result.OK { + http.Error(w, renderToolResultContent(result.Value), http.StatusBadRequest) + return + } + + r.mu.Lock() + r.requests = append(r.requests, completion) + index := len(r.requests) - 1 + r.mu.Unlock() + + if index >= len(r.responses) { + http.Error(w, "unexpected completion request", http.StatusInternalServerError) + return + } + writeSSE(w, r.responses[index]...) +} + +func (r *completionRecorder) Requests() []openAIRequest { + r.mu.Lock() + defer r.mu.Unlock() + return append([]openAIRequest(nil), r.requests...) +} + +func TestToolCallHandler_Good_ServiceDispatchesInlineToolCall(t *testing.T) { + executor := &strictToolExecutor{} + recorder := &completionRecorder{responses: [][]string{ + { + `{"id":"chatcmpl-1","choices":[{"delta":{"content":"{\"tool_call\":{\"name\":\"layout_suggest\",\"arguments\":{\"window_count\":2}}}"}}]}`, + `{"id":"chatcmpl-1","choices":[{"finish_reason":"stop"}]}`, + `[DONE]`, + }, + { + `{"id":"chatcmpl-2","choices":[{"delta":{"content":"Layout applied"}}]}`, + `{"id":"chatcmpl-2","choices":[{"finish_reason":"stop"}]}`, + `[DONE]`, + }, + }} + c := newChatCore(t, recorder.ServeHTTP, executor) + + send := c.Action("gui.chat.send").Run(context.Background(), core.NewOptions( + core.Option{Key: "content", Value: "Arrange this workspace"}, + )) + require.True(t, send.OK) + + calls := executor.Calls() + require.Len(t, calls, 1) + assert.Equal(t, "layout_suggest", calls[0].Name) + assert.Equal(t, float64(2), calls[0].Arguments["window_count"]) + + conv := latestConversation(t, c) + history := historyMessages(t, c, conv.ID, 0) + require.Len(t, history, 4) + assert.Equal(t, "assistant", history[1].Role) + require.Len(t, history[1].ToolCalls, 1) + assert.Equal(t, "tool", history[2].Role) + assert.Contains(t, history[2].Content, "left-right") + assert.Equal(t, "Layout applied", history[3].Content) + + requests := recorder.Requests() + require.Len(t, requests, 2) + require.NotEmpty(t, requests[0].Messages) + systemPrompt, ok := requests[0].Messages[0].Content.(string) + require.True(t, ok) + assert.True(t, strings.HasPrefix(systemPrompt, "Available MCP tools:")) + assert.Contains(t, systemPrompt, "layout_suggest") + assert.Contains(t, systemPrompt, "You are a helpful assistant.") +} + +func TestToolCallHandler_Bad_UnknownToolErrorAppearsInConversation(t *testing.T) { + executor := &strictToolExecutor{} + recorder := &completionRecorder{responses: [][]string{ + { + `{"id":"chatcmpl-1","choices":[{"delta":{"content":"{\"tool_call\":{\"name\":\"missing_tool\",\"arguments\":{}}}"}}]}`, + `{"id":"chatcmpl-1","choices":[{"finish_reason":"stop"}]}`, + `[DONE]`, + }, + { + `{"id":"chatcmpl-2","choices":[{"delta":{"content":"Could not run that tool"}}]}`, + `{"id":"chatcmpl-2","choices":[{"finish_reason":"stop"}]}`, + `[DONE]`, + }, + }} + c := newChatCore(t, recorder.ServeHTTP, executor) + + send := c.Action("gui.chat.send").Run(context.Background(), core.NewOptions( + core.Option{Key: "content", Value: "Use the missing tool"}, + )) + require.True(t, send.OK) + + conv := latestConversation(t, c) + history := historyMessages(t, c, conv.ID, 0) + require.Len(t, history, 4) + assert.Equal(t, "tool", history[2].Role) + assert.Contains(t, history[2].Content, "missing_tool") + assert.Equal(t, "Could not run that tool", history[3].Content) +} + +func TestToolCallHandler_Ugly_MalformedInlineToolCallDoesNotDispatch(t *testing.T) { + executor := &strictToolExecutor{} + recorder := &completionRecorder{responses: [][]string{{ + `{"id":"chatcmpl-1","choices":[{"delta":{"content":"{\"tool_call\":{\"name\":\"layout_suggest\",\"arguments\":"}}]}`, + `{"id":"chatcmpl-1","choices":[{"finish_reason":"stop"}]}`, + `[DONE]`, + }}} + c := newChatCore(t, recorder.ServeHTTP, executor) + + send := c.Action("gui.chat.send").Run(context.Background(), core.NewOptions( + core.Option{Key: "content", Value: "Try malformed JSON"}, + )) + require.True(t, send.OK) + + assert.Empty(t, executor.Calls()) + assert.Len(t, recorder.Requests(), 1) + + conv := latestConversation(t, c) + history := historyMessages(t, c, conv.ID, 0) + require.Len(t, history, 2) + assert.Equal(t, "assistant", history[1].Role) + assert.Contains(t, history[1].Content, "tool_call") + assert.Empty(t, history[1].ToolCalls) +}