feat(mcp): improve tool schema generation
This commit is contained in:
parent
c83df5f113
commit
b6aa33a8e0
2 changed files with 185 additions and 46 deletions
|
|
@ -5,6 +5,7 @@ package mcp
|
|||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
core "dappco.re/go/core"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
|
@ -80,52 +81,7 @@ func structSchema(v any) map[string]any {
|
|||
if t.Kind() != reflect.Struct {
|
||||
return nil
|
||||
}
|
||||
if t.NumField() == 0 {
|
||||
return map[string]any{"type": "object", "properties": map[string]any{}}
|
||||
}
|
||||
|
||||
properties := make(map[string]any)
|
||||
required := make([]string, 0)
|
||||
|
||||
for f := range t.Fields() {
|
||||
f := f
|
||||
if !f.IsExported() {
|
||||
continue
|
||||
}
|
||||
jsonTag := f.Tag.Get("json")
|
||||
if jsonTag == "-" {
|
||||
continue
|
||||
}
|
||||
name := f.Name
|
||||
isOptional := false
|
||||
if jsonTag != "" {
|
||||
parts := splitTag(jsonTag)
|
||||
name = parts[0]
|
||||
for _, p := range parts[1:] {
|
||||
if p == "omitempty" {
|
||||
isOptional = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
prop := map[string]any{
|
||||
"type": goTypeToJSONType(f.Type),
|
||||
}
|
||||
properties[name] = prop
|
||||
|
||||
if !isOptional {
|
||||
required = append(required, name)
|
||||
}
|
||||
}
|
||||
|
||||
schema := map[string]any{
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
}
|
||||
if len(required) > 0 {
|
||||
schema["required"] = required
|
||||
}
|
||||
return schema
|
||||
return schemaForType(t, map[reflect.Type]bool{})
|
||||
}
|
||||
|
||||
// splitTag splits a struct tag value by commas.
|
||||
|
|
@ -153,3 +109,120 @@ func goTypeToJSONType(t reflect.Type) string {
|
|||
return "string"
|
||||
}
|
||||
}
|
||||
|
||||
func schemaForType(t reflect.Type, seen map[reflect.Type]bool) map[string]any {
|
||||
if t == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
for t.Kind() == reflect.Pointer {
|
||||
t = t.Elem()
|
||||
if t == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if isTimeType(t) {
|
||||
return map[string]any{
|
||||
"type": "string",
|
||||
"format": "date-time",
|
||||
}
|
||||
}
|
||||
|
||||
switch t.Kind() {
|
||||
case reflect.Interface:
|
||||
return map[string]any{}
|
||||
|
||||
case reflect.Struct:
|
||||
if seen[t] {
|
||||
return map[string]any{"type": "object"}
|
||||
}
|
||||
seen[t] = true
|
||||
|
||||
properties := make(map[string]any)
|
||||
required := make([]string, 0, t.NumField())
|
||||
|
||||
for f := range t.Fields() {
|
||||
f := f
|
||||
if !f.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
jsonTag := f.Tag.Get("json")
|
||||
if jsonTag == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
name := f.Name
|
||||
isOptional := false
|
||||
if jsonTag != "" {
|
||||
parts := splitTag(jsonTag)
|
||||
name = parts[0]
|
||||
for _, p := range parts[1:] {
|
||||
if p == "omitempty" {
|
||||
isOptional = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
prop := schemaForType(f.Type, cloneSeenSet(seen))
|
||||
if prop == nil {
|
||||
prop = map[string]any{"type": goTypeToJSONType(f.Type)}
|
||||
}
|
||||
properties[name] = prop
|
||||
|
||||
if !isOptional {
|
||||
required = append(required, name)
|
||||
}
|
||||
}
|
||||
|
||||
schema := map[string]any{
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
}
|
||||
if len(required) > 0 {
|
||||
schema["required"] = required
|
||||
}
|
||||
return schema
|
||||
|
||||
case reflect.Slice, reflect.Array:
|
||||
schema := map[string]any{
|
||||
"type": "array",
|
||||
"items": schemaForType(t.Elem(), cloneSeenSet(seen)),
|
||||
}
|
||||
return schema
|
||||
|
||||
case reflect.Map:
|
||||
schema := map[string]any{
|
||||
"type": "object",
|
||||
}
|
||||
if t.Key().Kind() == reflect.String {
|
||||
if valueSchema := schemaForType(t.Elem(), cloneSeenSet(seen)); valueSchema != nil {
|
||||
schema["additionalProperties"] = valueSchema
|
||||
}
|
||||
}
|
||||
return schema
|
||||
|
||||
default:
|
||||
if typeName := goTypeToJSONType(t); typeName != "" {
|
||||
return map[string]any{"type": typeName}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func cloneSeenSet(seen map[reflect.Type]bool) map[reflect.Type]bool {
|
||||
if len(seen) == 0 {
|
||||
return map[reflect.Type]bool{}
|
||||
}
|
||||
clone := make(map[reflect.Type]bool, len(seen))
|
||||
for t := range seen {
|
||||
clone[t] = true
|
||||
}
|
||||
return clone
|
||||
}
|
||||
|
||||
func isTimeType(t reflect.Type) bool {
|
||||
return t == reflect.TypeOf(time.Time{})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ package mcp
|
|||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"forge.lthn.ai/core/go-process"
|
||||
)
|
||||
|
||||
func TestToolRegistry_Good_RecordsTools(t *testing.T) {
|
||||
|
|
@ -188,3 +190,67 @@ func TestToolRegistry_Good_ToolRecordFields(t *testing.T) {
|
|||
t.Error("expected non-nil OutputSchema")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_Good_TimeSchemas(t *testing.T) {
|
||||
svc, err := New(Options{
|
||||
WorkspaceRoot: t.TempDir(),
|
||||
ProcessService: &process.Service{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
byName := make(map[string]ToolRecord)
|
||||
for _, tr := range svc.Tools() {
|
||||
byName[tr.Name] = tr
|
||||
}
|
||||
|
||||
metrics, ok := byName["metrics_record"]
|
||||
if !ok {
|
||||
t.Fatal("metrics_record not found in registry")
|
||||
}
|
||||
inputProps, ok := metrics.InputSchema["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("expected metrics_record input properties map")
|
||||
}
|
||||
dataSchema, ok := inputProps["data"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("expected data schema for metrics_record input")
|
||||
}
|
||||
if got := dataSchema["type"]; got != "object" {
|
||||
t.Fatalf("expected metrics_record data type object, got %#v", got)
|
||||
}
|
||||
props, ok := metrics.OutputSchema["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("expected metrics_record output properties map")
|
||||
}
|
||||
timestamp, ok := props["timestamp"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("expected timestamp schema for metrics_record output")
|
||||
}
|
||||
if got := timestamp["type"]; got != "string" {
|
||||
t.Fatalf("expected metrics_record timestamp type string, got %#v", got)
|
||||
}
|
||||
if got := timestamp["format"]; got != "date-time" {
|
||||
t.Fatalf("expected metrics_record timestamp format date-time, got %#v", got)
|
||||
}
|
||||
|
||||
processStart, ok := byName["process_start"]
|
||||
if !ok {
|
||||
t.Fatal("process_start not found in registry")
|
||||
}
|
||||
props, ok = processStart.OutputSchema["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("expected process_start output properties map")
|
||||
}
|
||||
startedAt, ok := props["startedAt"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("expected startedAt schema for process_start output")
|
||||
}
|
||||
if got := startedAt["type"]; got != "string" {
|
||||
t.Fatalf("expected process_start startedAt type string, got %#v", got)
|
||||
}
|
||||
if got := startedAt["format"]; got != "date-time" {
|
||||
t.Fatalf("expected process_start startedAt format date-time, got %#v", got)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue