From b6aa33a8e0001fdc035408db06395fbd7d2a73e1 Mon Sep 17 00:00:00 2001 From: Virgil Date: Thu, 2 Apr 2026 16:58:53 +0000 Subject: [PATCH] feat(mcp): improve tool schema generation --- pkg/mcp/registry.go | 165 ++++++++++++++++++++++++++++----------- pkg/mcp/registry_test.go | 66 ++++++++++++++++ 2 files changed, 185 insertions(+), 46 deletions(-) diff --git a/pkg/mcp/registry.go b/pkg/mcp/registry.go index 91c6ccf..030dd41 100644 --- a/pkg/mcp/registry.go +++ b/pkg/mcp/registry.go @@ -5,6 +5,7 @@ package mcp import ( "context" "reflect" + "time" core "dappco.re/go/core" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -80,52 +81,7 @@ func structSchema(v any) map[string]any { if t.Kind() != reflect.Struct { return nil } - if t.NumField() == 0 { - return map[string]any{"type": "object", "properties": map[string]any{}} - } - - properties := make(map[string]any) - required := make([]string, 0) - - for f := range t.Fields() { - f := f - if !f.IsExported() { - continue - } - jsonTag := f.Tag.Get("json") - if jsonTag == "-" { - continue - } - name := f.Name - isOptional := false - if jsonTag != "" { - parts := splitTag(jsonTag) - name = parts[0] - for _, p := range parts[1:] { - if p == "omitempty" { - isOptional = true - } - } - } - - prop := map[string]any{ - "type": goTypeToJSONType(f.Type), - } - properties[name] = prop - - if !isOptional { - required = append(required, name) - } - } - - schema := map[string]any{ - "type": "object", - "properties": properties, - } - if len(required) > 0 { - schema["required"] = required - } - return schema + return schemaForType(t, map[reflect.Type]bool{}) } // splitTag splits a struct tag value by commas. @@ -153,3 +109,120 @@ func goTypeToJSONType(t reflect.Type) string { return "string" } } + +func schemaForType(t reflect.Type, seen map[reflect.Type]bool) map[string]any { + if t == nil { + return nil + } + + for t.Kind() == reflect.Pointer { + t = t.Elem() + if t == nil { + return nil + } + } + + if isTimeType(t) { + return map[string]any{ + "type": "string", + "format": "date-time", + } + } + + switch t.Kind() { + case reflect.Interface: + return map[string]any{} + + case reflect.Struct: + if seen[t] { + return map[string]any{"type": "object"} + } + seen[t] = true + + properties := make(map[string]any) + required := make([]string, 0, t.NumField()) + + for f := range t.Fields() { + f := f + if !f.IsExported() { + continue + } + + jsonTag := f.Tag.Get("json") + if jsonTag == "-" { + continue + } + + name := f.Name + isOptional := false + if jsonTag != "" { + parts := splitTag(jsonTag) + name = parts[0] + for _, p := range parts[1:] { + if p == "omitempty" { + isOptional = true + } + } + } + + prop := schemaForType(f.Type, cloneSeenSet(seen)) + if prop == nil { + prop = map[string]any{"type": goTypeToJSONType(f.Type)} + } + properties[name] = prop + + if !isOptional { + required = append(required, name) + } + } + + schema := map[string]any{ + "type": "object", + "properties": properties, + } + if len(required) > 0 { + schema["required"] = required + } + return schema + + case reflect.Slice, reflect.Array: + schema := map[string]any{ + "type": "array", + "items": schemaForType(t.Elem(), cloneSeenSet(seen)), + } + return schema + + case reflect.Map: + schema := map[string]any{ + "type": "object", + } + if t.Key().Kind() == reflect.String { + if valueSchema := schemaForType(t.Elem(), cloneSeenSet(seen)); valueSchema != nil { + schema["additionalProperties"] = valueSchema + } + } + return schema + + default: + if typeName := goTypeToJSONType(t); typeName != "" { + return map[string]any{"type": typeName} + } + } + + return nil +} + +func cloneSeenSet(seen map[reflect.Type]bool) map[reflect.Type]bool { + if len(seen) == 0 { + return map[reflect.Type]bool{} + } + clone := make(map[reflect.Type]bool, len(seen)) + for t := range seen { + clone[t] = true + } + return clone +} + +func isTimeType(t reflect.Type) bool { + return t == reflect.TypeOf(time.Time{}) +} diff --git a/pkg/mcp/registry_test.go b/pkg/mcp/registry_test.go index 0686fe5..cb59bab 100644 --- a/pkg/mcp/registry_test.go +++ b/pkg/mcp/registry_test.go @@ -4,6 +4,8 @@ package mcp import ( "testing" + + "forge.lthn.ai/core/go-process" ) func TestToolRegistry_Good_RecordsTools(t *testing.T) { @@ -188,3 +190,67 @@ func TestToolRegistry_Good_ToolRecordFields(t *testing.T) { t.Error("expected non-nil OutputSchema") } } + +func TestToolRegistry_Good_TimeSchemas(t *testing.T) { + svc, err := New(Options{ + WorkspaceRoot: t.TempDir(), + ProcessService: &process.Service{}, + }) + if err != nil { + t.Fatal(err) + } + + byName := make(map[string]ToolRecord) + for _, tr := range svc.Tools() { + byName[tr.Name] = tr + } + + metrics, ok := byName["metrics_record"] + if !ok { + t.Fatal("metrics_record not found in registry") + } + inputProps, ok := metrics.InputSchema["properties"].(map[string]any) + if !ok { + t.Fatal("expected metrics_record input properties map") + } + dataSchema, ok := inputProps["data"].(map[string]any) + if !ok { + t.Fatal("expected data schema for metrics_record input") + } + if got := dataSchema["type"]; got != "object" { + t.Fatalf("expected metrics_record data type object, got %#v", got) + } + props, ok := metrics.OutputSchema["properties"].(map[string]any) + if !ok { + t.Fatal("expected metrics_record output properties map") + } + timestamp, ok := props["timestamp"].(map[string]any) + if !ok { + t.Fatal("expected timestamp schema for metrics_record output") + } + if got := timestamp["type"]; got != "string" { + t.Fatalf("expected metrics_record timestamp type string, got %#v", got) + } + if got := timestamp["format"]; got != "date-time" { + t.Fatalf("expected metrics_record timestamp format date-time, got %#v", got) + } + + processStart, ok := byName["process_start"] + if !ok { + t.Fatal("process_start not found in registry") + } + props, ok = processStart.OutputSchema["properties"].(map[string]any) + if !ok { + t.Fatal("expected process_start output properties map") + } + startedAt, ok := props["startedAt"].(map[string]any) + if !ok { + t.Fatal("expected startedAt schema for process_start output") + } + if got := startedAt["type"]; got != "string" { + t.Fatalf("expected process_start startedAt type string, got %#v", got) + } + if got := startedAt["format"]; got != "date-time" { + t.Fatalf("expected process_start startedAt format date-time, got %#v", got) + } +}