// ---- pkg/mcp/transformer.go (new file) ----

// SPDX-License-Identifier: EUPL-1.2

package mcp

import (
	"bytes"
	"encoding/json"
	"mime"
	"strings"
)

// TransformerIn normalises an AI wire protocol request into a unified MCP
// request envelope.
type TransformerIn interface {
	// Detect reports whether this transformer claims the request, judged by
	// the request body, Content-Type media type, and URL path.
	Detect(body []byte, contentType, path string) bool
	// Normalise converts the wire-format body into the neutral MCPRequest.
	Normalise(body []byte) (MCPRequest, error)
}

// TransformerOut converts an MCP result back into an AI wire protocol response.
type TransformerOut interface {
	// Transform serialises the neutral result into the protocol's response
	// body bytes.
	Transform(result MCPResult) ([]byte, error)
}

// MCPRequest is the gateway's protocol-neutral JSON-RPC request shape.
type MCPRequest struct {
	JSONRPC string         `json:"jsonrpc,omitempty"`
	ID      any            `json:"id,omitempty"`
	Method  string         `json:"method,omitempty"`
	Params  map[string]any `json:"params,omitempty"`
}

// MCPResult is the gateway's protocol-neutral JSON-RPC result shape.
// Content/ToolCalls/StopReason are gateway-side extensions alongside the
// standard JSON-RPC result/error fields.
type MCPResult struct {
	JSONRPC    string        `json:"jsonrpc,omitempty"`
	ID         any           `json:"id,omitempty"`
	Result     any           `json:"result,omitempty"`
	Error      any           `json:"error,omitempty"`
	Content    []MCPContent  `json:"content,omitempty"`
	ToolCalls  []MCPToolCall `json:"tool_calls,omitempty"`
	StopReason string        `json:"stop_reason,omitempty"`
}

// MCPContent represents text and tool-use content blocks in the neutral result.
// For tool-use blocks either Input (Anthropic-style) or Arguments
// (OpenAI-style) may carry the call arguments.
type MCPContent struct {
	Type      string         `json:"type,omitempty"`
	Text      string         `json:"text,omitempty"`
	ID        string         `json:"id,omitempty"`
	Name      string         `json:"name,omitempty"`
	Input     map[string]any `json:"input,omitempty"`
	Arguments map[string]any `json:"arguments,omitempty"`
}

// MCPToolCall captures a model-requested tool invocation.
+type MCPToolCall struct { + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Arguments map[string]any `json:"arguments,omitempty"` +} + +// TODO(#197 follow-up): add Ollama and LiteLLM concrete transformers once the +// OpenAI/Anthropic/MCP-native gateway surface has settled. + +// NegotiateTransformer selects the inbound transformer using RFC ยง9.4 priority: +// explicit media type, path, body inspection, then MCP-native fallback. The +// honeypot is only selected for malformed or probe-like bodies that no concrete +// protocol claims. +func NegotiateTransformer(body []byte, contentType, path string) TransformerIn { + if headerHasMedia(contentType, "application/openai+json") { + return OpenAITransformer{} + } + if headerHasMedia(contentType, "application/anthropic+json") { + return AnthropicTransformer{} + } + if headerHasMedia(contentType, "application/mcp+json", "application/json-rpc", "application/jsonrpc+json") { + return MCPNativeTransformer{} + } + + switch normaliseGatewayPath(path) { + case "/v1/chat/completions": + return OpenAITransformer{} + case "/v1/messages": + return AnthropicTransformer{} + case "/mcp": + if (HoneypotTransformer{}).Detect(body, contentType, path) { + return HoneypotTransformer{} + } + return MCPNativeTransformer{} + } + + if (MCPNativeTransformer{}).Detect(body, "", "") { + return MCPNativeTransformer{} + } + if (OpenAITransformer{}).Detect(body, "", "") { + if looksAnthropicBody(body) { + return AnthropicTransformer{} + } + return OpenAITransformer{} + } + if (AnthropicTransformer{}).Detect(body, "", "") { + return AnthropicTransformer{} + } + if (HoneypotTransformer{}).Detect(body, contentType, path) { + return HoneypotTransformer{} + } + return MCPNativeTransformer{} +} + +// MCPNativeTransformer is the identity transformer for native MCP JSON-RPC. 
type MCPNativeTransformer struct{}

// Detect reports whether the request is native MCP JSON-RPC: an explicit
// JSON-RPC media type, the /mcp path, or a body carrying the JSON-RPC 2.0
// envelope fields.
func (MCPNativeTransformer) Detect(body []byte, contentType, path string) bool {
	if headerHasMedia(contentType, "application/mcp+json", "application/json-rpc", "application/jsonrpc+json") {
		return true
	}
	if normaliseGatewayPath(path) == "/mcp" {
		return true
	}

	obj, ok := decodeJSONObject(body)
	if !ok {
		return false
	}
	_, hasMethod := obj["method"].(string)
	_, hasResult := obj["result"]
	_, hasError := obj["error"]
	// A JSON-RPC message declares version 2.0 and is either a request
	// (method) or a response (result/error).
	return obj["jsonrpc"] == "2.0" && (hasMethod || hasResult || hasError)
}

// Normalise decodes the body directly into the neutral envelope, defaulting
// the JSON-RPC version when the client omitted it.
func (MCPNativeTransformer) Normalise(body []byte) (MCPRequest, error) {
	var req MCPRequest
	if err := json.Unmarshal(body, &req); err != nil {
		return MCPRequest{}, err
	}
	if req.JSONRPC == "" {
		req.JSONRPC = "2.0"
	}
	return req, nil
}

// Transform serialises the neutral result unchanged, defaulting the JSON-RPC
// version field.
func (MCPNativeTransformer) Transform(result MCPResult) ([]byte, error) {
	if result.JSONRPC == "" {
		result.JSONRPC = "2.0"
	}
	return json.Marshal(result)
}

// headerHasMedia reports whether any media type in a (possibly comma-joined)
// Content-Type style header value matches one of wants, comparing
// case-insensitively and ignoring media-type parameters.
func headerHasMedia(header string, wants ...string) bool {
	header = strings.TrimSpace(header)
	if header == "" {
		return false
	}

	wantSet := make(map[string]struct{}, len(wants))
	for _, want := range wants {
		wantSet[strings.ToLower(strings.TrimSpace(want))] = struct{}{}
	}

	for _, part := range strings.Split(header, ",") {
		media := strings.TrimSpace(part)
		if parsed, _, err := mime.ParseMediaType(media); err == nil {
			media = parsed
		} else if semi := strings.IndexByte(media, ';'); semi >= 0 {
			// Malformed media type: fall back to stripping parameters by hand.
			media = media[:semi]
		}
		media = strings.ToLower(strings.TrimSpace(media))
		if _, ok := wantSet[media]; ok {
			return true
		}
	}
	return false
}

// normaliseGatewayPath canonicalises a URL path for routing: strips query and
// fragment, guarantees a leading slash, collapses duplicate slashes, and
// trims trailing slashes (except for the bare root "/").
func normaliseGatewayPath(path string) string {
	path = strings.TrimSpace(path)
	if path == "" {
		return ""
	}
	if i := strings.IndexAny(path, "?#"); i >= 0 {
		path = path[:i]
	}
	if !strings.HasPrefix(path, "/") {
		path = "/" + path
	}
	for strings.Contains(path, "//") {
		path = strings.ReplaceAll(path, "//", "/")
	}
	if len(path) > 1 {
		path = strings.TrimRight(path, "/")
	}
	return path
}

// decodeJSONObject unmarshals body into a generic JSON object; ok is false
// for empty or invalid JSON and for non-object JSON values.
// NOTE(review): a bare JSON "null" decodes to a nil map with ok=true — confirm
// callers tolerate a nil (but ok) result.
func decodeJSONObject(body []byte) (map[string]any, bool) {
	body = bytes.TrimSpace(body)
	if len(body) == 0 {
		return nil, false
	}
	var obj map[string]any
	if err := json.Unmarshal(body, &obj); err != nil {
		return nil, false
	}
	return obj, true
}

// hasTopLevelFields reports whether body is a JSON object containing every
// named top-level field.
func hasTopLevelFields(body []byte, fields ...string) bool {
	obj, ok := decodeJSONObject(body)
	if !ok {
		return false
	}
	for _, field := range fields {
		if _, ok := obj[field]; !ok {
			return false
		}
	}
	return true
}

// looksAnthropicBody applies heuristics for the Anthropic Messages shape:
// Anthropic-leaning top-level fields (system, max_tokens, anthropic_version)
// or tool_use/tool_result content blocks inside messages. A message with the
// "system" role is OpenAI-style and disqualifies the body outright.
func looksAnthropicBody(body []byte) bool {
	obj, ok := decodeJSONObject(body)
	if !ok {
		return false
	}
	if _, ok := obj["system"]; ok {
		return true
	}
	// max_tokens also appears in OpenAI requests; treated here as an
	// Anthropic signal only after explicit OpenAI detection has had priority
	// in NegotiateTransformer.
	if _, ok := obj["max_tokens"]; ok {
		return true
	}
	if _, ok := obj["anthropic_version"]; ok {
		return true
	}

	messages, ok := obj["messages"].([]any)
	if !ok || len(messages) == 0 {
		return false
	}
	for _, raw := range messages {
		msg, ok := raw.(map[string]any)
		if !ok {
			continue
		}
		if role, _ := msg["role"].(string); role == "system" {
			return false
		}
		if blocks, ok := msg["content"].([]any); ok {
			for _, rawBlock := range blocks {
				block, ok := rawBlock.(map[string]any)
				if !ok {
					continue
				}
				switch block["type"] {
				case "tool_use", "tool_result":
					return true
				}
			}
		}
	}
	return false
}

// messagesHaveNoSystemRole reports whether body has a non-empty messages
// array containing no "system"-role message (Anthropic carries the system
// prompt at the top level instead of inside messages).
func messagesHaveNoSystemRole(body []byte) bool {
	obj, ok := decodeJSONObject(body)
	if !ok {
		return false
	}
	messages, ok := obj["messages"].([]any)
	if !ok || len(messages) == 0 {
		return false
	}
	for _, raw := range messages {
		msg, ok := raw.(map[string]any)
		if !ok {
			continue
		}
		if role, _ := msg["role"].(string); role == "system" {
			return false
		}
	}
	return true
}

// parseRawArgumentObject decodes tool-call arguments that may arrive as a
// JSON object, a JSON string wrapping an object (OpenAI encodes function
// arguments this way), or arbitrary bytes. Unparseable payloads are preserved
// under the "_raw" key rather than dropped.
func parseRawArgumentObject(raw json.RawMessage) map[string]any {
	raw = bytes.TrimSpace(raw)
	if len(raw) == 0 || bytes.Equal(raw, []byte("null")) {
		return map[string]any{}
	}

	// Try the string-wrapped form first, then a direct object.
	var encoded string
	if err := json.Unmarshal(raw, &encoded); err == nil {
		return parseArgumentString(encoded)
	}

	var args map[string]any
	if err := json.Unmarshal(raw, &args); err == nil && args != nil {
		return args
	}
	return map[string]any{"_raw": string(raw)}
}

// parseArgumentString parses a string payload as a JSON object, falling back
// to preserving the raw text under "_raw".
func parseArgumentString(s string) map[string]any {
	s = strings.TrimSpace(s)
	if s == "" {
		return map[string]any{}
	}
	var args map[string]any
	if err := json.Unmarshal([]byte(s), &args); err == nil && args != nil {
		return args
	}
	return map[string]any{"_raw": s}
}

// mapFromAny coerces an arbitrary decoded value into an argument map,
// re-encoding unknown types through JSON as a last resort (and wrapping them
// under "value" if even that fails).
func mapFromAny(v any) map[string]any {
	switch typed := v.(type) {
	case nil:
		return map[string]any{}
	case map[string]any:
		if typed == nil {
			return map[string]any{}
		}
		return typed
	case json.RawMessage:
		return parseRawArgumentObject(typed)
	case string:
		return parseArgumentString(typed)
	default:
		data, err := json.Marshal(typed)
		if err != nil {
			return map[string]any{"value": typed}
		}
		return parseRawArgumentObject(data)
	}
}

// extractMCPText gathers text from the neutral result's Content blocks and
// from the nested Result payload, joined by newlines.
func extractMCPText(result MCPResult) string {
	var parts []string
	for _, block := range result.Content {
		// Untyped blocks are treated as text for leniency.
		if block.Text != "" && (block.Type == "" || block.Type == "text") {
			parts = append(parts, block.Text)
		}
	}
	parts = append(parts, extractTextFromAny(result.Result)...)
	return strings.Join(parts, "\n")
}

// extractTextFromAny recursively collects human-readable text from loosely
// typed result payloads: strings, content-block slices, and maps keyed by
// text/message/output or nesting content/result. Unrecognised values are
// JSON-encoded verbatim as a fallback.
func extractTextFromAny(v any) []string {
	switch typed := v.(type) {
	case nil:
		return nil
	case string:
		if typed == "" {
			return nil
		}
		return []string{typed}
	case []byte:
		if len(typed) == 0 {
			return nil
		}
		return []string{string(typed)}
	case []MCPContent:
		var out []string
		for _, block := range typed {
			if block.Text != "" && (block.Type == "" || block.Type == "text") {
				out = append(out, block.Text)
			}
		}
		return out
	case []any:
		var out []string
		for _, item := range typed {
			out = append(out, extractTextFromAny(item)...)
		}
		return out
	case []map[string]any:
		var out []string
		for _, item := range typed {
			out = append(out, extractTextFromAny(item)...)
		}
		return out
	case map[string]any:
		// Prefer well-known scalar keys before descending into nesting.
		for _, key := range []string{"text", "message", "output"} {
			if text, ok := typed[key].(string); ok && text != "" {
				return []string{text}
			}
		}
		if content, ok := typed["content"]; ok {
			return extractTextFromAny(content)
		}
		if result, ok := typed["result"]; ok {
			return extractTextFromAny(result)
		}
		return nil
	default:
		data, err := json.Marshal(typed)
		if err != nil || len(data) == 0 || bytes.Equal(data, []byte("null")) {
			return nil
		}
		return []string{string(data)}
	}
}

// extractMCPToolCalls gathers tool invocations from the explicit ToolCalls
// list, from tool_use content blocks, and from the nested Result payload.
func extractMCPToolCalls(result MCPResult) []MCPToolCall {
	var calls []MCPToolCall
	calls = append(calls, result.ToolCalls...)
	for _, block := range result.Content {
		// Skip only blocks that are neither typed tool_use nor named —
		// NOTE(review): a non-tool_use block that carries a Name is still
		// treated as a call; confirm this leniency is intended.
		if block.Type != "tool_use" && block.Name == "" {
			continue
		}
		args := block.Input
		if len(args) == 0 {
			args = block.Arguments
		}
		calls = append(calls, MCPToolCall{ID: block.ID, Name: block.Name, Arguments: args})
	}
	calls = append(calls, extractToolCallsFromAny(result.Result)...)
	return calls
}

// extractToolCallsFromAny recursively mines loosely typed payloads for tool
// calls, understanding both the OpenAI shape ({"id","function":{"name",
// "arguments"}}) and the Anthropic shape ({"id","name","input"}).
func extractToolCallsFromAny(v any) []MCPToolCall {
	switch typed := v.(type) {
	case nil:
		return nil
	case []MCPToolCall:
		return typed
	case []MCPContent:
		var calls []MCPToolCall
		for _, block := range typed {
			if block.Type == "tool_use" || block.Name != "" {
				args := block.Input
				if len(args) == 0 {
					args = block.Arguments
				}
				calls = append(calls, MCPToolCall{ID: block.ID, Name: block.Name, Arguments: args})
			}
		}
		return calls
	case []any:
		var calls []MCPToolCall
		for _, item := range typed {
			calls = append(calls, extractToolCallsFromAny(item)...)
		}
		return calls
	case []map[string]any:
		var calls []MCPToolCall
		for _, item := range typed {
			calls = append(calls, extractToolCallsFromAny(item)...)
		}
		return calls
	case map[string]any:
		// Container keys take priority over treating the map as a call itself.
		for _, key := range []string{"tool_calls", "toolCalls"} {
			if raw, ok := typed[key]; ok {
				return extractToolCallsFromAny(raw)
			}
		}
		if raw, ok := typed["content"]; ok {
			return extractToolCallsFromAny(raw)
		}
		name, _ := typed["name"].(string)
		if name == "" {
			// OpenAI nests the callable under "function".
			if fn, ok := typed["function"].(map[string]any); ok {
				name, _ = fn["name"].(string)
				args := mapFromAny(fn["arguments"])
				id, _ := typed["id"].(string)
				return []MCPToolCall{{ID: id, Name: name, Arguments: args}}
			}
			return nil
		}
		id, _ := typed["id"].(string)
		args := mapFromAny(typed["arguments"])
		if len(args) == 0 {
			args = mapFromAny(typed["input"])
		}
		return []MCPToolCall{{ID: id, Name: name, Arguments: args}}
	default:
		return nil
	}
}

// ---- pkg/mcp/transformer_anthropic.go (new file) ----

// SPDX-License-Identifier: EUPL-1.2

package mcp

import (
	"encoding/json"
	"fmt"
)

// AnthropicTransformer maps Anthropic Messages requests and responses.
type AnthropicTransformer struct{}

// Detect reports whether the request is an Anthropic Messages call: explicit
// media type, the /v1/messages path, or a model+messages body that carries
// Anthropic markers (or at least lacks an OpenAI-style system-role message).
func (AnthropicTransformer) Detect(body []byte, contentType, path string) bool {
	if headerHasMedia(contentType, "application/anthropic+json") {
		return true
	}
	if normaliseGatewayPath(path) == "/v1/messages" {
		return true
	}
	if !hasTopLevelFields(body, "model", "messages") {
		return false
	}
	return looksAnthropicBody(body) || messagesHaveNoSystemRole(body)
}

// Normalise converts an Anthropic Messages request into the neutral MCP
// envelope. If the latest assistant turn contains tool_use blocks the request
// becomes tools/call (first call promoted to name/arguments); otherwise it is
// a sampling/createMessage.
func (AnthropicTransformer) Normalise(body []byte) (MCPRequest, error) {
	var req anthropicMessagesRequest
	if err := json.Unmarshal(body, &req); err != nil {
		return MCPRequest{}, err
	}
	if req.Model == "" {
		return MCPRequest{}, fmt.Errorf("anthropic messages request missing model")
	}
	if len(req.Messages) == 0 {
		return MCPRequest{}, fmt.Errorf("anthropic messages request missing messages")
	}

	params := map[string]any{
		"source_format": "anthropic",
		"model":         req.Model,
		"messages":      normaliseAnthropicMessages(req.Messages),
	}
	// Optional fields are copied through only when the client set them.
	if req.System != nil {
		params["system"] = req.System
	}
	if req.MaxTokens != nil {
		params["max_tokens"] = req.MaxTokens
	}
	if req.Temperature != nil {
		params["temperature"] = req.Temperature
	}
	if req.Stream {
		params["stream"] = req.Stream
	}
	if len(req.Tools) > 0 {
		params["tools"] = normaliseAnthropicTools(req.Tools)
	}

	toolCalls := anthropicToolUsesFromMessages(req.Messages)
	if len(toolCalls) > 0 {
		call := toolCalls[0]
		params["name"] = call.Name
		params["arguments"] = call.Arguments
		params["tool_calls"] = toolCalls
		return MCPRequest{JSONRPC: "2.0", Method: "tools/call", Params: params}, nil
	}

	return MCPRequest{JSONRPC: "2.0", Method: "sampling/createMessage", Params: params}, nil
}

// Transform renders the neutral result as an Anthropic message response:
// text and tool_use content blocks, synthetic toolu_N IDs when missing, and a
// stop_reason of tool_use/end_turn unless the result supplies its own.
func (AnthropicTransformer) Transform(result MCPResult) ([]byte, error) {
	text := extractMCPText(result)
	toolCalls := extractMCPToolCalls(result)

	content := make([]map[string]any, 0, 1+len(toolCalls))
	if text != "" {
		content = append(content, map[string]any{
			"type": "text",
			"text": text,
		})
	}
	for i, call := range toolCalls {
		id := call.ID
		if id == "" {
			id = fmt.Sprintf("toolu_%d", i)
		}
		content = append(content, map[string]any{
			"type":  "tool_use",
			"id":    id,
			"name":  call.Name,
			"input": call.Arguments,
		})
	}
	// Anthropic responses always have at least one content block.
	if len(content) == 0 {
		content = append(content, map[string]any{
			"type": "text",
			"text": "",
		})
	}

	stopReason := "end_turn"
	if len(toolCalls) > 0 {
		stopReason = "tool_use"
	}
	if result.StopReason != "" {
		stopReason = result.StopReason
	}

	resp := map[string]any{
		"id":            anthropicResponseID(result.ID),
		"type":          "message",
		"role":          "assistant",
		"model":         "mcp-gateway",
		"content":       content,
		"stop_reason":   stopReason,
		"stop_sequence": nil,
	}
	return json.Marshal(resp)
}

// anthropicMessagesRequest mirrors the subset of the Anthropic Messages API
// request the gateway understands.
type anthropicMessagesRequest struct {
	Model       string             `json:"model"`
	MaxTokens   any                `json:"max_tokens,omitempty"`
	System      any                `json:"system,omitempty"`
	Messages    []anthropicMessage `json:"messages"`
	Tools       []anthropicTool    `json:"tools,omitempty"`
	Temperature any                `json:"temperature,omitempty"`
	Stream      bool               `json:"stream,omitempty"`
}

// anthropicMessage is one conversation turn; Content may be a string or a
// list of content blocks.
type anthropicMessage struct {
	Role    string `json:"role"`
	Content any    `json:"content,omitempty"`
}

// anthropicTool is an Anthropic tool definition.
type anthropicTool struct {
	Name        string `json:"name"`
	Description string `json:"description,omitempty"`
	InputSchema any    `json:"input_schema,omitempty"`
}

// normaliseAnthropicMessages copies messages into plain maps, dropping nil
// content.
func normaliseAnthropicMessages(messages []anthropicMessage) []map[string]any {
	out := make([]map[string]any, 0, len(messages))
	for _, msg := range messages {
		item := map[string]any{
			"role": msg.Role,
		}
		if msg.Content != nil {
			item["content"] = msg.Content
		}
		out = append(out, item)
	}
	return out
}

// normaliseAnthropicTools copies tool definitions into plain maps.
func normaliseAnthropicTools(tools []anthropicTool) []map[string]any {
	out := make([]map[string]any, 0, len(tools))
	for _, tool := range tools {
		out = append(out, map[string]any{
			"name":         tool.Name,
			"description":  tool.Description,
			"input_schema": tool.InputSchema,
		})
	}
	return out
}

// anthropicToolUsesFromMessages returns the tool_use calls from the most
// recent message that has any, scanning newest-first and stopping at the
// first match.
func anthropicToolUsesFromMessages(messages []anthropicMessage) []MCPToolCall {
	var calls []MCPToolCall
	for i := len(messages) - 1; i >= 0; i-- {
		blocks := anthropicContentBlocks(messages[i].Content)
		for _, block := range blocks {
			if block.Type != "tool_use" || block.Name == "" {
				continue
			}
			calls = append(calls, MCPToolCall{
				ID:        block.ID,
				Name:      block.Name,
				Arguments: block.Input,
			})
		}
		if len(calls) > 0 {
			break
		}
	}
	return calls
}

// anthropicContentBlock is a decoded Anthropic content block (text or
// tool_use fields populated depending on Type).
type anthropicContentBlock struct {
	Type  string         `json:"type"`
	Text  string         `json:"text,omitempty"`
	ID    string         `json:"id,omitempty"`
	Name  string         `json:"name,omitempty"`
	Input map[string]any `json:"input,omitempty"`
}

// anthropicContentBlocks coerces the polymorphic Content field (string, map,
// or list) into typed blocks via a JSON round-trip; undecodable items are
// skipped.
func anthropicContentBlocks(content any) []anthropicContentBlock {
	switch typed := content.(type) {
	case nil:
		return nil
	case []anthropicContentBlock:
		return typed
	case []any:
		blocks := make([]anthropicContentBlock, 0, len(typed))
		for _, item := range typed {
			data, err := json.Marshal(item)
			if err != nil {
				continue
			}
			var block anthropicContentBlock
			if err := json.Unmarshal(data, &block); err == nil {
				blocks = append(blocks, block)
			}
		}
		return blocks
	case map[string]any:
		data, err := json.Marshal(typed)
		if err != nil {
			return nil
		}
		var block anthropicContentBlock
		if err := json.Unmarshal(data, &block); err != nil {
			return nil
		}
		return []anthropicContentBlock{block}
	case string:
		// A bare string is shorthand for a single text block.
		return []anthropicContentBlock{{Type: "text", Text: typed}}
	default:
		return nil
	}
}

// anthropicResponseID derives an Anthropic-style msg_* response ID from the
// JSON-RPC request ID.
func anthropicResponseID(id any) string {
	if id == nil {
		return "msg_mcp"
	}
	return fmt.Sprintf("msg_%v", id)
}

// ---- pkg/mcp/transformer_honeypot.go (new file) ----

// SPDX-License-Identifier: EUPL-1.2

package mcp

import (
	"bytes"
	"encoding/json"
	"fmt"
	"strings"
)

// HoneypotTransformer absorbs malformed or probe-like input and returns a
// plausible synthetic response without dispatching to real tools.
type HoneypotTransformer struct{}

// Detect claims non-empty bodies that are invalid JSON, valid JSON that is
// not an object, or JSON objects whose body/headers/path contain known
// prompt-injection probe markers.
func (HoneypotTransformer) Detect(body []byte, contentType, path string) bool {
	trimmed := bytes.TrimSpace(body)
	if len(trimmed) == 0 {
		return false
	}
	if !json.Valid(trimmed) {
		return true
	}

	var obj map[string]any
	if err := json.Unmarshal(trimmed, &obj); err != nil {
		// Valid JSON but not an object (array, string, number…) is still
		// treated as probe-like.
		return true
	}
	return looksProbeLike(trimmed, contentType, path)
}

// Normalise wraps the raw (truncated) body in a honeypot/respond envelope so
// downstream handlers can log or synthesise a reply; no real tool runs.
func (HoneypotTransformer) Normalise(body []byte) (MCPRequest, error) {
	params := map[string]any{
		"source_format": "honeypot",
		"raw":           honeypotSnippet(body),
		"malformed":     !json.Valid(bytes.TrimSpace(body)),
	}
	return MCPRequest{
		JSONRPC: "2.0",
		Method:  "honeypot/respond",
		Params:  params,
	}, nil
}

// Transform renders an OpenAI-shaped chat.completion response carrying either
// the result's text or a generic deflection message, with zeroed usage.
func (HoneypotTransformer) Transform(result MCPResult) ([]byte, error) {
	text := extractMCPText(result)
	if text == "" {
		text = "Request received. The gateway is processing the available context and will return compatible MCP output when a valid protocol envelope is provided."
	}

	resp := map[string]any{
		"id":      honeypotResponseID(result.ID),
		"object":  "chat.completion",
		"created": 0,
		"model":   "mcp-gateway",
		"choices": []map[string]any{
			{
				"index": 0,
				"message": map[string]any{
					"role":    "assistant",
					"content": text,
				},
				"finish_reason": "stop",
			},
		},
		"usage": map[string]any{
			"prompt_tokens":     0,
			"completion_tokens": 0,
			"total_tokens":      0,
		},
	}
	return json.Marshal(resp)
}

// looksProbeLike reports whether the body, content type, or path contains any
// known prompt-injection / path-traversal probe marker (case-insensitive).
func looksProbeLike(body []byte, contentType, path string) bool {
	haystack := strings.ToLower(strings.Join([]string{
		string(body),
		contentType,
		path,
	}, "\n"))
	for _, marker := range []string{
		"ignore previous",
		"system prompt",
		"developer message",
		"/etc/passwd",
		"../../",
		"dump secrets",
		"jailbreak",
		"prompt injection",
	} {
		if strings.Contains(haystack, marker) {
			return true
		}
	}
	return false
}

// honeypotSnippet trims and caps the body at 4096 bytes for safe logging.
// NOTE(review): the byte-level cut can split a multi-byte UTF-8 rune — confirm
// consumers tolerate a trailing partial rune.
func honeypotSnippet(body []byte) string {
	s := string(bytes.TrimSpace(body))
	const max = 4096
	if len(s) <= max {
		return s
	}
	return s[:max]
}

// honeypotResponseID derives a chatcmpl-honeypot-* response ID from the
// JSON-RPC request ID.
func honeypotResponseID(id any) string {
	if id == nil {
		return "chatcmpl-honeypot"
	}
	return fmt.Sprintf("chatcmpl-honeypot-%v", id)
}

// ---- pkg/mcp/transformer_openai.go (new file) ----

// SPDX-License-Identifier: EUPL-1.2

package mcp

import (
	"encoding/json"
	"fmt"
)

// OpenAITransformer maps OpenAI Chat Completions requests and responses.
type OpenAITransformer struct{}

// Detect reports whether the request is an OpenAI Chat Completions call:
// explicit media type, the /v1/chat/completions path, or a body with
// top-level model and messages fields.
func (OpenAITransformer) Detect(body []byte, contentType, path string) bool {
	if headerHasMedia(contentType, "application/openai+json") {
		return true
	}
	if normaliseGatewayPath(path) == "/v1/chat/completions" {
		return true
	}
	return hasTopLevelFields(body, "model", "messages")
}

// Normalise converts an OpenAI Chat Completions request into the neutral MCP
// envelope. If the latest assistant turn carries tool_calls the request
// becomes tools/call (first call promoted to name/arguments); otherwise it is
// a sampling/createMessage.
func (OpenAITransformer) Normalise(body []byte) (MCPRequest, error) {
	var req openAIChatCompletionRequest
	if err := json.Unmarshal(body, &req); err != nil {
		return MCPRequest{}, err
	}
	if req.Model == "" {
		return MCPRequest{}, fmt.Errorf("openai chat completion request missing model")
	}
	if len(req.Messages) == 0 {
		return MCPRequest{}, fmt.Errorf("openai chat completion request missing messages")
	}

	params := map[string]any{
		"source_format": "openai",
		"model":         req.Model,
		"messages":      normaliseOpenAIMessages(req.Messages),
	}
	// Optional fields are copied through only when the client set them.
	if len(req.Tools) > 0 {
		params["tools"] = normaliseOpenAITools(req.Tools)
	}
	if req.ToolChoice != nil {
		params["tool_choice"] = req.ToolChoice
	}
	if req.MaxTokens != nil {
		params["max_tokens"] = req.MaxTokens
	}
	if req.MaxCompletionTokens != nil {
		params["max_completion_tokens"] = req.MaxCompletionTokens
	}
	if req.Temperature != nil {
		params["temperature"] = req.Temperature
	}
	if req.Stream {
		params["stream"] = req.Stream
	}

	toolCalls := openAIToolCallsFromMessages(req.Messages)
	if len(toolCalls) > 0 {
		call := toolCalls[0]
		params["name"] = call.Name
		params["arguments"] = call.Arguments
		params["tool_calls"] = toolCalls
		return MCPRequest{JSONRPC: "2.0", Method: "tools/call", Params: params}, nil
	}

	return MCPRequest{JSONRPC: "2.0", Method: "sampling/createMessage", Params: params}, nil
}

// Transform renders the neutral result as an OpenAI chat.completion response:
// assistant message (content null when only tool calls exist), synthetic
// call_N IDs when missing, and finish_reason tool_calls/stop unless the
// result supplies its own.
func (OpenAITransformer) Transform(result MCPResult) ([]byte, error) {
	text := extractMCPText(result)
	toolCalls := extractMCPToolCalls(result)

	message := map[string]any{
		"role": "assistant",
	}
	if text != "" {
		message["content"] = text
	} else if len(toolCalls) > 0 {
		// OpenAI sets content to null on pure tool-call turns.
		message["content"] = nil
	} else {
		message["content"] = ""
	}
	if len(toolCalls) > 0 {
		message["tool_calls"] = openAIToolCallsFromMCP(toolCalls)
	}

	finishReason := "stop"
	if len(toolCalls) > 0 {
		finishReason = "tool_calls"
	}
	if result.StopReason != "" {
		finishReason = result.StopReason
	}

	resp := map[string]any{
		"id":      openAIResponseID(result.ID),
		"object":  "chat.completion",
		"created": 0,
		"model":   "mcp-gateway",
		"choices": []map[string]any{
			{
				"index":         0,
				"message":       message,
				"finish_reason": finishReason,
			},
		},
	}
	return json.Marshal(resp)
}

// openAIChatCompletionRequest mirrors the subset of the OpenAI Chat
// Completions request the gateway understands.
type openAIChatCompletionRequest struct {
	Model               string          `json:"model"`
	Messages            []openAIMessage `json:"messages"`
	Tools               []openAITool    `json:"tools,omitempty"`
	ToolChoice          any             `json:"tool_choice,omitempty"`
	MaxTokens           any             `json:"max_tokens,omitempty"`
	MaxCompletionTokens any             `json:"max_completion_tokens,omitempty"`
	Temperature         any             `json:"temperature,omitempty"`
	Stream              bool            `json:"stream,omitempty"`
}

// openAIMessage is one conversation turn, including assistant tool_calls and
// tool-result linkage fields.
type openAIMessage struct {
	Role       string           `json:"role"`
	Content    any              `json:"content,omitempty"`
	Name       string           `json:"name,omitempty"`
	ToolCallID string           `json:"tool_call_id,omitempty"`
	ToolCalls  []openAIToolCall `json:"tool_calls,omitempty"`
}

// openAITool is an OpenAI tool definition wrapper ("function" type).
type openAITool struct {
	Type     string                 `json:"type"`
	Function openAIFunctionMetadata `json:"function"`
}

// openAIFunctionMetadata describes a callable function and its JSON schema.
type openAIFunctionMetadata struct {
	Name        string `json:"name"`
	Description string `json:"description,omitempty"`
	Parameters  any    `json:"parameters,omitempty"`
}

// openAIToolCall is a model-emitted tool invocation.
type openAIToolCall struct {
	ID       string             `json:"id,omitempty"`
	Type     string             `json:"type,omitempty"`
	Function openAIFunctionCall `json:"function"`
}

// openAIFunctionCall carries the function name and its (string-encoded)
// argument payload.
type openAIFunctionCall struct {
	Name      string          `json:"name"`
	Arguments json.RawMessage `json:"arguments,omitempty"`
}

// normaliseOpenAIMessages copies messages into plain maps, re-normalising any
// embedded tool calls into the neutral MCPToolCall shape.
func normaliseOpenAIMessages(messages []openAIMessage) []map[string]any {
	out := make([]map[string]any, 0, len(messages))
	for _, msg := range messages {
		item := map[string]any{
			"role": msg.Role,
		}
		if msg.Content != nil {
			item["content"] = msg.Content
		}
		if msg.Name != "" {
			item["name"] = msg.Name
		}
		if msg.ToolCallID != "" {
			item["tool_call_id"] = msg.ToolCallID
		}
		if len(msg.ToolCalls) > 0 {
			// Reuse the single-message extractor to convert this turn's calls.
			item["tool_calls"] = openAIToolCallsFromMessages([]openAIMessage{msg})
		}
		out = append(out, item)
	}
	return out
}

// normaliseOpenAITools flattens "function" tools into the gateway's
// {name, description, input_schema} shape; tools of any other type are passed
// through unflattened.
func normaliseOpenAITools(tools []openAITool) []map[string]any {
	out := make([]map[string]any, 0, len(tools))
	for _, tool := range tools {
		if tool.Type != "" && tool.Type != "function" {
			out = append(out, map[string]any{
				"type":     tool.Type,
				"function": tool.Function,
			})
			continue
		}
		item := map[string]any{
			"name":         tool.Function.Name,
			"description":  tool.Function.Description,
			"input_schema": tool.Function.Parameters,
		}
		out = append(out, item)
	}
	return out
}

// openAIToolCallsFromMessages returns the tool calls from the most recent
// message that has any, scanning newest-first; nameless calls are skipped.
func openAIToolCallsFromMessages(messages []openAIMessage) []MCPToolCall {
	var calls []MCPToolCall
	for i := len(messages) - 1; i >= 0; i-- {
		msg := messages[i]
		if len(msg.ToolCalls) == 0 {
			continue
		}
		for _, call := range msg.ToolCalls {
			if call.Function.Name == "" {
				continue
			}
			calls = append(calls, MCPToolCall{
				ID:        call.ID,
				Name:      call.Function.Name,
				Arguments: parseRawArgumentObject(call.Function.Arguments),
			})
		}
		break
	}
	return calls
}

// openAIToolCallsFromMCP renders neutral tool calls in the OpenAI wire shape,
// string-encoding arguments and synthesising call_N IDs when missing.
func openAIToolCallsFromMCP(calls []MCPToolCall) []map[string]any {
	out := make([]map[string]any, 0, len(calls))
	for i, call := range calls {
		id := call.ID
		if id == "" {
			id = fmt.Sprintf("call_%d", i)
		}
		args, err := json.Marshal(call.Arguments)
		if err != nil {
			args = []byte("{}")
		}
		out = append(out, map[string]any{
			"id":   id,
			"type": "function",
			"function": map[string]any{
				"name":      call.Name,
				"arguments": string(args),
			},
		})
	}
	return out
}

// openAIResponseID derives an OpenAI-style chatcmpl-* response ID from the
// JSON-RPC request ID.
func openAIResponseID(id any) string {
	if id == nil {
		return "chatcmpl-mcp"
	}
	return fmt.Sprintf("chatcmpl-%v", id)
}

// ---- pkg/mcp/transformer_test.go (new file) ----

// SPDX-License-Identifier: EUPL-1.2

package mcp

import (
	"encoding/json"
	"testing"
)

// Path-based negotiation selects the OpenAI transformer.
func TestNegotiate_OpenAI_Good(t *testing.T) {
	body := []byte(`{"model":"gpt-4o-mini","messages":[{"role":"user","content":"hello"}]}`)

	if _, ok := NegotiateTransformer(body, "", "/v1/chat/completions").(OpenAITransformer); !ok {
		t.Fatal("expected OpenAITransformer for chat completions path")
	}
}

// Path-based negotiation selects the Anthropic transformer.
func TestNegotiate_Anthropic_Good(t *testing.T) {
	body := []byte(`{"model":"claude-3-5-sonnet","max_tokens":128,"messages":[{"role":"user","content":"hello"}]}`)

	if _, ok := NegotiateTransformer(body, "", "/v1/messages").(AnthropicTransformer); !ok {
		t.Fatal("expected AnthropicTransformer for messages path")
	}
}

// Media type plus /mcp path selects the native transformer.
func TestNegotiate_MCPNative_Good(t *testing.T) {
	body := []byte(`{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}`)

	if _, ok := NegotiateTransformer(body, "application/mcp+json", "/mcp").(MCPNativeTransformer); !ok {
		t.Fatal("expected MCPNativeTransformer for native MCP request")
	}
}

// An assistant turn with tool_calls normalises to a tools/call request with
// the string-encoded arguments decoded into a map.
func TestOpenAITransformer_Normalise_Good(t *testing.T) {
	body := []byte(`{
		"model": "gpt-4o",
		"messages": [
			{
				"role": "assistant",
				"content": null,
				"tool_calls": [
					{
						"id": "call_1",
						"type": "function",
						"function": {
							"name": "file_read",
							"arguments": "{\"path\":\"README.md\"}"
						}
					}
				]
			}
		]
	}`)

	req, err := (OpenAITransformer{}).Normalise(body)
	if err != nil {
		t.Fatalf("Normalise failed: %v", err)
	}
	if req.JSONRPC != "2.0" {
		t.Fatalf("expected JSON-RPC 2.0, got %q", req.JSONRPC)
	}
	if req.Method != "tools/call" {
		t.Fatalf("expected tools/call, got %q", req.Method)
	}
	if req.Params["source_format"] != "openai" {
		t.Fatalf("expected source_format openai, got %v", req.Params["source_format"])
	}
	if req.Params["model"] != "gpt-4o" {
		t.Fatalf("expected model to be preserved, got %v", req.Params["model"])
	}
	if req.Params["name"] != "file_read" {
		t.Fatalf("expected tool name file_read, got %v", req.Params["name"])
	}
	args, ok := req.Params["arguments"].(map[string]any)
	if !ok {
		t.Fatalf("expected argument map, got %T", req.Params["arguments"])
	}
	if args["path"] != "README.md" {
		t.Fatalf("expected README.md path, got %v", args["path"])
	}
}

// A nested MCP content result renders as an OpenAI chat.completion body.
func TestOpenAITransformer_Transform_Good(t *testing.T) {
	data, err := (OpenAITransformer{}).Transform(MCPResult{
		ID: 7,
		Result: map[string]any{
			"content": []any{
				map[string]any{"type": "text", "text": "done"},
			},
		},
	})
	if err != nil {
		t.Fatalf("Transform failed: %v", err)
	}

	var resp map[string]any
	if err := json.Unmarshal(data, &resp); err != nil {
		t.Fatalf("response is not JSON: %v", err)
	}
	if resp["object"] != "chat.completion" {
		t.Fatalf("expected chat.completion object, got %v", resp["object"])
	}
	choices := resp["choices"].([]any)
	message := choices[0].(map[string]any)["message"].(map[string]any)
	if message["content"] != "done" {
		t.Fatalf("expected content done, got %v", message["content"])
	}
}

// An assistant turn with a tool_use block normalises to a tools/call request.
func TestAnthropicTransformer_Normalise_Good(t *testing.T) {
	body := []byte(`{
		"model": "claude-3-5-sonnet",
		"max_tokens": 256,
		"messages": [
			{
				"role": "assistant",
				"content": [
					{
						"type": "tool_use",
						"id": "toolu_1",
						"name": "file_read",
						"input": {"path":"README.md"}
					}
				]
			}
		]
	}`)

	req, err := (AnthropicTransformer{}).Normalise(body)
	if err != nil {
		t.Fatalf("Normalise failed: %v", err)
	}
	if req.Method != "tools/call" {
		t.Fatalf("expected tools/call, got %q", req.Method)
	}
	if req.Params["source_format"] != "anthropic" {
		t.Fatalf("expected source_format anthropic, got %v", req.Params["source_format"])
	}
	if req.Params["name"] != "file_read" {
		t.Fatalf("expected tool name file_read, got %v", req.Params["name"])
	}
	args, ok := req.Params["arguments"].(map[string]any)
	if !ok {
		t.Fatalf("expected argument map, got %T", req.Params["arguments"])
	}
	if args["path"] != "README.md" {
		t.Fatalf("expected README.md path, got %v", args["path"])
	}
}

// A text content result renders as an Anthropic message body.
func TestAnthropicTransformer_Transform_Good(t *testing.T) {
	data, err := (AnthropicTransformer{}).Transform(MCPResult{
		ID:      "abc",
		Content: []MCPContent{{Type: "text", Text: "done"}},
	})
	if err != nil {
		t.Fatalf("Transform failed: %v", err)
	}

	var resp map[string]any
	if err := json.Unmarshal(data, &resp); err != nil {
		t.Fatalf("response is not JSON: %v", err)
	}
	if resp["type"] != "message" {
		t.Fatalf("expected message type, got %v", resp["type"])
	}
	content := resp["content"].([]any)
	first := content[0].(map[string]any)
	if first["text"] != "done" {
		t.Fatalf("expected text done, got %v", first["text"])
	}
}

// Malformed bodies are claimed by the honeypot both directly and via
// negotiation.
func TestHoneypotTransformer_Detect_FallbackOnGarbage(t *testing.T) {
	body := []byte(`{not-json`)

	if !(HoneypotTransformer{}).Detect(body, "", "/probe") {
		t.Fatal("expected honeypot to detect malformed input")
	}
	if _, ok := NegotiateTransformer(body, "", "/probe").(HoneypotTransformer); !ok {
		t.Fatal("expected negotiation to select honeypot for malformed input")
	}
}

// An explicit OpenAI media type outranks an Anthropic path and body.
func TestNegotiate_Priority_Ugly(t *testing.T) {
	body := []byte(`{"model":"claude-3-5-sonnet","max_tokens":128,"messages":[{"role":"user","content":"hello"}]}`)

	if _, ok := NegotiateTransformer(body, "application/openai+json", "/v1/messages").(OpenAITransformer); !ok {
		t.Fatal("expected explicit OpenAI media type to beat path/body inspection")
	}
}