feat(session): validate agent type against spec

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Virgil 2026-04-02 05:17:37 +00:00
parent 2b40d0a3b0
commit 7253e1240e
3 changed files with 44 additions and 9 deletions

View file

@ -102,17 +102,17 @@ func TestCommandsSession_CmdSessionStart_Good(t *testing.T) {
var payload map[string]any
parseResult := core.JSONUnmarshalString(bodyResult.Value.(string), &payload)
require.True(t, parseResult.OK)
assert.Equal(t, "codex", payload["agent_type"])
assert.Equal(t, "opus", payload["agent_type"])
assert.Equal(t, "ax-follow-up", payload["plan_slug"])
_, _ = w.Write([]byte(`{"data":{"session_id":"ses-start","plan_slug":"ax-follow-up","agent_type":"codex","status":"active"}}`))
_, _ = w.Write([]byte(`{"data":{"session_id":"ses-start","plan_slug":"ax-follow-up","agent_type":"opus","status":"active"}}`))
}))
defer server.Close()
subsystem := testPrepWithPlatformServer(t, server, "secret-token")
result := subsystem.cmdSessionStart(core.NewOptions(
core.Option{Key: "_arg", Value: "ax-follow-up"},
core.Option{Key: "agent_type", Value: "codex"},
core.Option{Key: "agent_type", Value: "opus"},
))
require.True(t, result.OK)
@ -120,19 +120,32 @@ func TestCommandsSession_CmdSessionStart_Good(t *testing.T) {
require.True(t, ok)
assert.Equal(t, "ses-start", output.Session.SessionID)
assert.Equal(t, "ax-follow-up", output.Session.PlanSlug)
assert.Equal(t, "codex", output.Session.AgentType)
assert.Equal(t, "opus", output.Session.AgentType)
}
func TestCommandsSession_CmdSessionStart_Bad_MissingPlanSlug(t *testing.T) {
subsystem := testPrepWithPlatformServer(t, nil, "secret-token")
result := subsystem.cmdSessionStart(core.NewOptions(core.Option{Key: "agent_type", Value: "codex"}))
result := subsystem.cmdSessionStart(core.NewOptions(core.Option{Key: "agent_type", Value: "opus"}))
assert.False(t, result.OK)
require.Error(t, result.Value.(error))
assert.Contains(t, result.Value.(error).Error(), "plan_slug is required")
}
func TestCommandsSession_CmdSessionStart_Bad_InvalidAgentType(t *testing.T) {
subsystem := testPrepWithPlatformServer(t, nil, "secret-token")
result := subsystem.cmdSessionStart(core.NewOptions(
core.Option{Key: "_arg", Value: "ax-follow-up"},
core.Option{Key: "agent_type", Value: "codex"},
))
assert.False(t, result.OK)
require.Error(t, result.Value.(error))
assert.Contains(t, result.Value.(error).Error(), "opus, sonnet, or haiku")
}
func TestCommandsSession_CmdSessionStart_Ugly_InvalidResponse(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(`{"data":`))

View file

@ -364,6 +364,9 @@ func (s *PrepSubsystem) sessionStart(ctx context.Context, _ *mcp.CallToolRequest
if input.AgentType == "" {
return nil, SessionOutput{}, core.E("sessionStart", "agent_type is required", nil)
}
if !validSessionAgentType(input.AgentType) {
return nil, SessionOutput{}, core.E("sessionStart", "agent_type must be opus, sonnet, or haiku", nil)
}
body := map[string]any{
"agent_type": input.AgentType,
@ -1137,3 +1140,12 @@ func resultErrorValue(action string, result core.Result) error {
return core.E(action, "request failed", nil)
}
func validSessionAgentType(agentType string) bool {
switch core.Lower(core.Trim(agentType)) {
case "opus", "sonnet", "haiku":
return true
default:
return false
}
}

View file

@ -25,16 +25,16 @@ func TestSession_HandleSessionStart_Good(t *testing.T) {
var payload map[string]any
parseResult := core.JSONUnmarshalString(bodyResult.Value.(string), &payload)
require.True(t, parseResult.OK)
require.Equal(t, "codex", payload["agent_type"])
require.Equal(t, "opus", payload["agent_type"])
require.Equal(t, "ax-follow-up", payload["plan_slug"])
_, _ = w.Write([]byte(`{"data":{"id":1,"session_id":"ses_abc123","plan_slug":"ax-follow-up","agent_type":"codex","status":"active","context_summary":{"repo":"core/go"}}}`))
_, _ = w.Write([]byte(`{"data":{"id":1,"session_id":"ses_abc123","plan_slug":"ax-follow-up","agent_type":"opus","status":"active","context_summary":{"repo":"core/go"}}}`))
}))
defer server.Close()
subsystem := testPrepWithPlatformServer(t, server, "secret-token")
result := subsystem.handleSessionStart(context.Background(), core.NewOptions(
core.Option{Key: "agent_type", Value: "codex"},
core.Option{Key: "agent_type", Value: "opus"},
core.Option{Key: "plan_slug", Value: "ax-follow-up"},
core.Option{Key: "context", Value: `{"repo":"core/go"}`},
))
@ -44,7 +44,7 @@ func TestSession_HandleSessionStart_Good(t *testing.T) {
require.True(t, ok)
assert.Equal(t, "ses_abc123", output.Session.SessionID)
assert.Equal(t, "active", output.Session.Status)
assert.Equal(t, "codex", output.Session.AgentType)
assert.Equal(t, "opus", output.Session.AgentType)
}
func TestSession_HandleSessionStart_Bad(t *testing.T) {
@ -54,6 +54,16 @@ func TestSession_HandleSessionStart_Bad(t *testing.T) {
assert.False(t, result.OK)
}
func TestSession_HandleSessionStart_Bad_InvalidAgentType(t *testing.T) {
subsystem := testPrepWithPlatformServer(t, nil, "secret-token")
result := subsystem.handleSessionStart(context.Background(), core.NewOptions(
core.Option{Key: "agent_type", Value: "codex"},
))
assert.False(t, result.OK)
require.Contains(t, result.Value.(error).Error(), "opus, sonnet, or haiku")
}
func TestSession_HandleSessionStart_Ugly(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(`{"data":`))