diff --git a/pkg/agentic/commands_session_test.go b/pkg/agentic/commands_session_test.go index e296a28..c7deae5 100644 --- a/pkg/agentic/commands_session_test.go +++ b/pkg/agentic/commands_session_test.go @@ -102,17 +102,17 @@ func TestCommandsSession_CmdSessionStart_Good(t *testing.T) { var payload map[string]any parseResult := core.JSONUnmarshalString(bodyResult.Value.(string), &payload) require.True(t, parseResult.OK) - assert.Equal(t, "codex", payload["agent_type"]) + assert.Equal(t, "opus", payload["agent_type"]) assert.Equal(t, "ax-follow-up", payload["plan_slug"]) - _, _ = w.Write([]byte(`{"data":{"session_id":"ses-start","plan_slug":"ax-follow-up","agent_type":"codex","status":"active"}}`)) + _, _ = w.Write([]byte(`{"data":{"session_id":"ses-start","plan_slug":"ax-follow-up","agent_type":"opus","status":"active"}}`)) })) defer server.Close() subsystem := testPrepWithPlatformServer(t, server, "secret-token") result := subsystem.cmdSessionStart(core.NewOptions( core.Option{Key: "_arg", Value: "ax-follow-up"}, - core.Option{Key: "agent_type", Value: "codex"}, + core.Option{Key: "agent_type", Value: "opus"}, )) require.True(t, result.OK) @@ -120,19 +120,32 @@ func TestCommandsSession_CmdSessionStart_Good(t *testing.T) { require.True(t, ok) assert.Equal(t, "ses-start", output.Session.SessionID) assert.Equal(t, "ax-follow-up", output.Session.PlanSlug) - assert.Equal(t, "codex", output.Session.AgentType) + assert.Equal(t, "opus", output.Session.AgentType) } func TestCommandsSession_CmdSessionStart_Bad_MissingPlanSlug(t *testing.T) { subsystem := testPrepWithPlatformServer(t, nil, "secret-token") - result := subsystem.cmdSessionStart(core.NewOptions(core.Option{Key: "agent_type", Value: "codex"})) + result := subsystem.cmdSessionStart(core.NewOptions(core.Option{Key: "agent_type", Value: "opus"})) assert.False(t, result.OK) require.Error(t, result.Value.(error)) assert.Contains(t, result.Value.(error).Error(), "plan_slug is required") } +func TestCommandsSession_CmdSessionStart_Bad_InvalidAgentType(t *testing.T) { + subsystem := testPrepWithPlatformServer(t, nil, "secret-token") + + result := subsystem.cmdSessionStart(core.NewOptions( + core.Option{Key: "_arg", Value: "ax-follow-up"}, + core.Option{Key: "agent_type", Value: "codex"}, + )) + + assert.False(t, result.OK) + require.Error(t, result.Value.(error)) + assert.Contains(t, result.Value.(error).Error(), "opus, sonnet, or haiku") +} + func TestCommandsSession_CmdSessionStart_Ugly_InvalidResponse(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte(`{"data":`)) diff --git a/pkg/agentic/session.go b/pkg/agentic/session.go index daa3774..4b690ab 100644 --- a/pkg/agentic/session.go +++ b/pkg/agentic/session.go @@ -364,6 +364,9 @@ func (s *PrepSubsystem) sessionStart(ctx context.Context, _ *mcp.CallToolRequest if input.AgentType == "" { return nil, SessionOutput{}, core.E("sessionStart", "agent_type is required", nil) } + if !validSessionAgentType(input.AgentType) { + return nil, SessionOutput{}, core.E("sessionStart", "agent_type must be opus, sonnet, or haiku", nil) + } body := map[string]any{ "agent_type": input.AgentType, @@ -1137,3 +1140,12 @@ func resultErrorValue(action string, result core.Result) error { return core.E(action, "request failed", nil) } + +func validSessionAgentType(agentType string) bool { + switch core.Lower(core.Trim(agentType)) { + case "opus", "sonnet", "haiku": + return true + default: + return false + } +} diff --git a/pkg/agentic/session_test.go b/pkg/agentic/session_test.go index 7719595..feb4d2b 100644 --- a/pkg/agentic/session_test.go +++ b/pkg/agentic/session_test.go @@ -25,16 +25,16 @@ func TestSession_HandleSessionStart_Good(t *testing.T) { var payload map[string]any parseResult := core.JSONUnmarshalString(bodyResult.Value.(string), &payload) require.True(t, parseResult.OK) - require.Equal(t, "codex", payload["agent_type"]) + require.Equal(t, "opus", payload["agent_type"]) require.Equal(t, "ax-follow-up", payload["plan_slug"]) - _, _ = w.Write([]byte(`{"data":{"id":1,"session_id":"ses_abc123","plan_slug":"ax-follow-up","agent_type":"codex","status":"active","context_summary":{"repo":"core/go"}}}`)) + _, _ = w.Write([]byte(`{"data":{"id":1,"session_id":"ses_abc123","plan_slug":"ax-follow-up","agent_type":"opus","status":"active","context_summary":{"repo":"core/go"}}}`)) })) defer server.Close() subsystem := testPrepWithPlatformServer(t, server, "secret-token") result := subsystem.handleSessionStart(context.Background(), core.NewOptions( - core.Option{Key: "agent_type", Value: "codex"}, + core.Option{Key: "agent_type", Value: "opus"}, core.Option{Key: "plan_slug", Value: "ax-follow-up"}, core.Option{Key: "context", Value: `{"repo":"core/go"}`}, )) @@ -44,7 +44,7 @@ func TestSession_HandleSessionStart_Good(t *testing.T) { require.True(t, ok) assert.Equal(t, "ses_abc123", output.Session.SessionID) assert.Equal(t, "active", output.Session.Status) - assert.Equal(t, "codex", output.Session.AgentType) + assert.Equal(t, "opus", output.Session.AgentType) } func TestSession_HandleSessionStart_Bad(t *testing.T) { @@ -54,6 +54,16 @@ func TestSession_HandleSessionStart_Bad(t *testing.T) { assert.False(t, result.OK) } +func TestSession_HandleSessionStart_Bad_InvalidAgentType(t *testing.T) { + subsystem := testPrepWithPlatformServer(t, nil, "secret-token") + + result := subsystem.handleSessionStart(context.Background(), core.NewOptions( + core.Option{Key: "agent_type", Value: "codex"}, + )) + assert.False(t, result.OK) + require.Contains(t, result.Value.(error).Error(), "opus, sonnet, or haiku") +} + func TestSession_HandleSessionStart_Ugly(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte(`{"data":`))