feat(session): validate agent type against spec
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
2b40d0a3b0
commit
7253e1240e
3 changed files with 44 additions and 9 deletions
|
|
@ -102,17 +102,17 @@ func TestCommandsSession_CmdSessionStart_Good(t *testing.T) {
|
|||
var payload map[string]any
|
||||
parseResult := core.JSONUnmarshalString(bodyResult.Value.(string), &payload)
|
||||
require.True(t, parseResult.OK)
|
||||
assert.Equal(t, "codex", payload["agent_type"])
|
||||
assert.Equal(t, "opus", payload["agent_type"])
|
||||
assert.Equal(t, "ax-follow-up", payload["plan_slug"])
|
||||
|
||||
_, _ = w.Write([]byte(`{"data":{"session_id":"ses-start","plan_slug":"ax-follow-up","agent_type":"codex","status":"active"}}`))
|
||||
_, _ = w.Write([]byte(`{"data":{"session_id":"ses-start","plan_slug":"ax-follow-up","agent_type":"opus","status":"active"}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
subsystem := testPrepWithPlatformServer(t, server, "secret-token")
|
||||
result := subsystem.cmdSessionStart(core.NewOptions(
|
||||
core.Option{Key: "_arg", Value: "ax-follow-up"},
|
||||
core.Option{Key: "agent_type", Value: "codex"},
|
||||
core.Option{Key: "agent_type", Value: "opus"},
|
||||
))
|
||||
require.True(t, result.OK)
|
||||
|
||||
|
|
@ -120,19 +120,32 @@ func TestCommandsSession_CmdSessionStart_Good(t *testing.T) {
|
|||
require.True(t, ok)
|
||||
assert.Equal(t, "ses-start", output.Session.SessionID)
|
||||
assert.Equal(t, "ax-follow-up", output.Session.PlanSlug)
|
||||
assert.Equal(t, "codex", output.Session.AgentType)
|
||||
assert.Equal(t, "opus", output.Session.AgentType)
|
||||
}
|
||||
|
||||
func TestCommandsSession_CmdSessionStart_Bad_MissingPlanSlug(t *testing.T) {
|
||||
subsystem := testPrepWithPlatformServer(t, nil, "secret-token")
|
||||
|
||||
result := subsystem.cmdSessionStart(core.NewOptions(core.Option{Key: "agent_type", Value: "codex"}))
|
||||
result := subsystem.cmdSessionStart(core.NewOptions(core.Option{Key: "agent_type", Value: "opus"}))
|
||||
|
||||
assert.False(t, result.OK)
|
||||
require.Error(t, result.Value.(error))
|
||||
assert.Contains(t, result.Value.(error).Error(), "plan_slug is required")
|
||||
}
|
||||
|
||||
func TestCommandsSession_CmdSessionStart_Bad_InvalidAgentType(t *testing.T) {
|
||||
subsystem := testPrepWithPlatformServer(t, nil, "secret-token")
|
||||
|
||||
result := subsystem.cmdSessionStart(core.NewOptions(
|
||||
core.Option{Key: "_arg", Value: "ax-follow-up"},
|
||||
core.Option{Key: "agent_type", Value: "codex"},
|
||||
))
|
||||
|
||||
assert.False(t, result.OK)
|
||||
require.Error(t, result.Value.(error))
|
||||
assert.Contains(t, result.Value.(error).Error(), "opus, sonnet, or haiku")
|
||||
}
|
||||
|
||||
func TestCommandsSession_CmdSessionStart_Ugly_InvalidResponse(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write([]byte(`{"data":`))
|
||||
|
|
|
|||
|
|
@ -364,6 +364,9 @@ func (s *PrepSubsystem) sessionStart(ctx context.Context, _ *mcp.CallToolRequest
|
|||
if input.AgentType == "" {
|
||||
return nil, SessionOutput{}, core.E("sessionStart", "agent_type is required", nil)
|
||||
}
|
||||
if !validSessionAgentType(input.AgentType) {
|
||||
return nil, SessionOutput{}, core.E("sessionStart", "agent_type must be opus, sonnet, or haiku", nil)
|
||||
}
|
||||
|
||||
body := map[string]any{
|
||||
"agent_type": input.AgentType,
|
||||
|
|
@ -1137,3 +1140,12 @@ func resultErrorValue(action string, result core.Result) error {
|
|||
|
||||
return core.E(action, "request failed", nil)
|
||||
}
|
||||
|
||||
func validSessionAgentType(agentType string) bool {
|
||||
switch core.Lower(core.Trim(agentType)) {
|
||||
case "opus", "sonnet", "haiku":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,16 +25,16 @@ func TestSession_HandleSessionStart_Good(t *testing.T) {
|
|||
var payload map[string]any
|
||||
parseResult := core.JSONUnmarshalString(bodyResult.Value.(string), &payload)
|
||||
require.True(t, parseResult.OK)
|
||||
require.Equal(t, "codex", payload["agent_type"])
|
||||
require.Equal(t, "opus", payload["agent_type"])
|
||||
require.Equal(t, "ax-follow-up", payload["plan_slug"])
|
||||
|
||||
_, _ = w.Write([]byte(`{"data":{"id":1,"session_id":"ses_abc123","plan_slug":"ax-follow-up","agent_type":"codex","status":"active","context_summary":{"repo":"core/go"}}}`))
|
||||
_, _ = w.Write([]byte(`{"data":{"id":1,"session_id":"ses_abc123","plan_slug":"ax-follow-up","agent_type":"opus","status":"active","context_summary":{"repo":"core/go"}}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
subsystem := testPrepWithPlatformServer(t, server, "secret-token")
|
||||
result := subsystem.handleSessionStart(context.Background(), core.NewOptions(
|
||||
core.Option{Key: "agent_type", Value: "codex"},
|
||||
core.Option{Key: "agent_type", Value: "opus"},
|
||||
core.Option{Key: "plan_slug", Value: "ax-follow-up"},
|
||||
core.Option{Key: "context", Value: `{"repo":"core/go"}`},
|
||||
))
|
||||
|
|
@ -44,7 +44,7 @@ func TestSession_HandleSessionStart_Good(t *testing.T) {
|
|||
require.True(t, ok)
|
||||
assert.Equal(t, "ses_abc123", output.Session.SessionID)
|
||||
assert.Equal(t, "active", output.Session.Status)
|
||||
assert.Equal(t, "codex", output.Session.AgentType)
|
||||
assert.Equal(t, "opus", output.Session.AgentType)
|
||||
}
|
||||
|
||||
func TestSession_HandleSessionStart_Bad(t *testing.T) {
|
||||
|
|
@ -54,6 +54,16 @@ func TestSession_HandleSessionStart_Bad(t *testing.T) {
|
|||
assert.False(t, result.OK)
|
||||
}
|
||||
|
||||
func TestSession_HandleSessionStart_Bad_InvalidAgentType(t *testing.T) {
|
||||
subsystem := testPrepWithPlatformServer(t, nil, "secret-token")
|
||||
|
||||
result := subsystem.handleSessionStart(context.Background(), core.NewOptions(
|
||||
core.Option{Key: "agent_type", Value: "codex"},
|
||||
))
|
||||
assert.False(t, result.OK)
|
||||
require.Contains(t, result.Value.(error).Error(), "opus, sonnet, or haiku")
|
||||
}
|
||||
|
||||
func TestSession_HandleSessionStart_Ugly(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write([]byte(`{"data":`))
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue