From 4ff21338ee54f44f8dad7f23a2e0b1d408b56fa1 Mon Sep 17 00:00:00 2001 From: Virgil Date: Wed, 1 Apr 2026 16:38:21 +0000 Subject: [PATCH] feat(agentic): expose PR close as MCP tool Co-Authored-By: Virgil --- pkg/agentic/commands_forge.go | 13 +-- pkg/agentic/pr.go | 61 ++++++++++++ pkg/agentic/pr_test.go | 179 ++++++++++++++++++++-------------- pkg/agentic/prep.go | 1 + 4 files changed, 177 insertions(+), 77 deletions(-) diff --git a/pkg/agentic/commands_forge.go b/pkg/agentic/commands_forge.go index c6ca558..0ac9bce 100644 --- a/pkg/agentic/commands_forge.go +++ b/pkg/agentic/commands_forge.go @@ -371,17 +371,18 @@ func (s *PrepSubsystem) cmdPRClose(options core.Options) core.Result { return core.Result{Value: core.E("agentic.cmdPRClose", "repo and number are required", nil), OK: false} } - var pr pullRequestView - err := s.forge.Client().Patch(ctx, core.Sprintf("/api/v1/repos/%s/%s/pulls/%d", org, repo, num), &forge_types.EditPullRequestOption{ - State: "closed", - }, &pr) + _, output, err := s.closePR(ctx, nil, ClosePRInput{ + Org: org, + Repo: repo, + Number: int(num), + }) if err != nil { core.Print(nil, "error: %v", err) return core.Result{Value: err, OK: false} } - core.Print(nil, "closed %s/%s#%d", org, repo, num) - return core.Result{OK: true} + core.Print(nil, "closed %s/%s#%d", output.Org, output.Repo, output.Number) + return core.Result{Value: output, OK: true} } func (s *PrepSubsystem) cmdRepoGet(options core.Options) core.Result { diff --git a/pkg/agentic/pr.go b/pkg/agentic/pr.go index 640e560..93e0387 100644 --- a/pkg/agentic/pr.go +++ b/pkg/agentic/pr.go @@ -181,6 +181,22 @@ type ListPRsOutput struct { PRs []PRInfo `json:"prs"` } +// input := agentic.ClosePRInput{Org: "core", Repo: "go-io", Number: 12} +type ClosePRInput struct { + Org string `json:"org,omitempty"` + Repo string `json:"repo"` + Number int `json:"number"` +} + +// out := agentic.ClosePROutput{Success: true, Repo: "go-io", Number: 12, State: "closed"} +type ClosePROutput struct { + Success bool `json:"success"` + Org string `json:"org,omitempty"` + Repo string `json:"repo"` + Number int `json:"number"` + State string `json:"state,omitempty"` +} + // pr := agentic.PRInfo{Repo: "go-io", Number: 12, Title: "Migrate pkg/fs", Branch: "agent/migrate-fs"} type PRInfo struct { Repo string `json:"repo"` @@ -202,6 +218,13 @@ func (s *PrepSubsystem) registerListPRsTool(server *mcp.Server) { }, s.listPRs) } +func (s *PrepSubsystem) registerClosePRTool(server *mcp.Server) { + mcp.AddTool(server, &mcp.Tool{ + Name: "agentic_close_pr", + Description: "Close a pull request on Forge by repository and pull request number.", + }, s.closePR) +} + func (s *PrepSubsystem) listPRs(ctx context.Context, _ *mcp.CallToolRequest, input ListPRsInput) (*mcp.CallToolResult, ListPRsOutput, error) { if s.forgeToken == "" { return nil, ListPRsOutput{}, core.E("listPRs", "no Forge token configured", nil) @@ -253,6 +276,44 @@ func (s *PrepSubsystem) listPRs(ctx context.Context, _ *mcp.CallToolRequest, inp }, nil } +func (s *PrepSubsystem) closePR(ctx context.Context, _ *mcp.CallToolRequest, input ClosePRInput) (*mcp.CallToolResult, ClosePROutput, error) { + if s.forgeToken == "" { + return nil, ClosePROutput{}, core.E("closePR", "no Forge token configured", nil) + } + if s.forge == nil { + return nil, ClosePROutput{}, core.E("closePR", "forge client is not configured", nil) + } + if input.Repo == "" || input.Number <= 0 { + return nil, ClosePROutput{}, core.E("closePR", "repo and number are required", nil) + } + + org := input.Org + if org == "" { + org = "core" + } + + var pr pullRequestView + err := s.forge.Client().Patch(ctx, core.Sprintf("/api/v1/repos/%s/%s/pulls/%d", org, input.Repo, input.Number), &forge_types.EditPullRequestOption{ + State: "closed", + }, &pr) + if err != nil { + return nil, ClosePROutput{}, core.E("closePR", core.Concat("failed to close PR ", core.Sprint(input.Number)), err) + } + + state := pr.State + if state == "" { + state = "closed" + } + + return nil, ClosePROutput{ + Success: true, + Org: org, + Repo: input.Repo, + Number: input.Number, + State: state, + }, nil +} + func (s *PrepSubsystem) listRepoPRs(ctx context.Context, org, repo, state string) ([]PRInfo, error) { var pullRequests []pullRequestView err := s.forge.Client().Get(ctx, core.Sprintf("/api/v1/repos/%s/%s/pulls?limit=50&page=1", org, repo), &pullRequests) diff --git a/pkg/agentic/pr_test.go b/pkg/agentic/pr_test.go index 558864d..81bf1c9 100644 --- a/pkg/agentic/pr_test.go +++ b/pkg/agentic/pr_test.go @@ -63,11 +63,11 @@ func TestPr_ForgeCreatePR_Good_Success(t *testing.T) { srv := mockPRForgeServer(t) s := &PrepSubsystem{ ServiceRuntime: core.NewServiceRuntime(testCore, AgentOptions{}), - forge: forge.NewForge(srv.URL, "test-token"), - forgeURL: srv.URL, - forgeToken: "test-token", - backoff: make(map[string]time.Time), - failCount: make(map[string]int), + forge: forge.NewForge(srv.URL, "test-token"), + forgeURL: srv.URL, + forgeToken: "test-token", + backoff: make(map[string]time.Time), + failCount: make(map[string]int), } prURL, prNum, err := s.forgeCreatePR( @@ -90,11 +90,11 @@ func TestPr_ForgeCreatePR_Bad_ServerError(t *testing.T) { s := &PrepSubsystem{ ServiceRuntime: core.NewServiceRuntime(testCore, AgentOptions{}), - forge: forge.NewForge(srv.URL, "test-token"), - forgeURL: srv.URL, - forgeToken: "test-token", - backoff: make(map[string]time.Time), - failCount: make(map[string]int), + forge: forge.NewForge(srv.URL, "test-token"), + forgeURL: srv.URL, + forgeToken: "test-token", + backoff: make(map[string]time.Time), + failCount: make(map[string]int), } _, _, err := s.forgeCreatePR( @@ -111,9 +111,9 @@ func TestPr_ForgeCreatePR_Bad_ServerError(t *testing.T) { func TestPr_CreatePR_Bad_NoWorkspace(t *testing.T) { s := &PrepSubsystem{ ServiceRuntime: core.NewServiceRuntime(testCore, AgentOptions{}), - forgeToken: "test-token", - backoff: make(map[string]time.Time), - failCount: make(map[string]int), + forgeToken: "test-token", + backoff: make(map[string]time.Time), + failCount: make(map[string]int), } _, _, err := s.createPR(context.Background(), nil, CreatePRInput{}) @@ -124,9 +124,9 @@ func TestPr_CreatePR_Bad_NoWorkspace(t *testing.T) { func TestPr_CreatePR_Bad_NoToken(t *testing.T) { s := &PrepSubsystem{ ServiceRuntime: core.NewServiceRuntime(testCore, AgentOptions{}), - forgeToken: "", - backoff: make(map[string]time.Time), - failCount: make(map[string]int), + forgeToken: "", + backoff: make(map[string]time.Time), + failCount: make(map[string]int), } _, _, err := s.createPR(context.Background(), nil, CreatePRInput{ @@ -142,9 +142,9 @@ func TestPr_CreatePR_Bad_WorkspaceNotFound(t *testing.T) { s := &PrepSubsystem{ ServiceRuntime: core.NewServiceRuntime(testCore, AgentOptions{}), - forgeToken: "test-token", - backoff: make(map[string]time.Time), - failCount: make(map[string]int), + forgeToken: "test-token", + backoff: make(map[string]time.Time), + failCount: make(map[string]int), } _, _, err := s.createPR(context.Background(), nil, CreatePRInput{ @@ -174,9 +174,9 @@ func TestPr_CreatePR_Good_DryRun(t *testing.T) { s := &PrepSubsystem{ ServiceRuntime: core.NewServiceRuntime(testCore, AgentOptions{}), - forgeToken: "test-token", - backoff: make(map[string]time.Time), - failCount: make(map[string]int), + forgeToken: "test-token", + backoff: make(map[string]time.Time), + failCount: make(map[string]int), } _, out, err := s.createPR(context.Background(), nil, CreatePRInput{ @@ -209,9 +209,9 @@ func TestPr_CreatePR_Good_CustomTitle(t *testing.T) { s := &PrepSubsystem{ ServiceRuntime: core.NewServiceRuntime(testCore, AgentOptions{}), - forgeToken: "test-token", - backoff: make(map[string]time.Time), - failCount: make(map[string]int), + forgeToken: "test-token", + backoff: make(map[string]time.Time), + failCount: make(map[string]int), } _, out, err := s.createPR(context.Background(), nil, CreatePRInput{ @@ -223,14 +223,51 @@ func TestPr_CreatePR_Good_CustomTitle(t *testing.T) { assert.Equal(t, "Custom PR title", out.Title) } +func TestPr_ClosePR_Good_Success(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPatch, r.Method) + assert.Equal(t, "/api/v1/repos/core/test-repo/pulls/7", r.URL.Path) + + bodyResult := core.ReadAll(r.Body) + assert.True(t, bodyResult.OK) + assert.Contains(t, bodyResult.Value.(string), `"state":"closed"`) + + w.Write([]byte(core.JSONMarshalString(map[string]any{ + "number": 7, + "state": "closed", + }))) + })) + t.Cleanup(srv.Close) + + s := &PrepSubsystem{ + ServiceRuntime: core.NewServiceRuntime(testCore, AgentOptions{}), + forge: forge.NewForge(srv.URL, "test-token"), + forgeURL: srv.URL, + forgeToken: "test-token", + backoff: make(map[string]time.Time), + failCount: make(map[string]int), + } + + _, out, err := s.closePR(context.Background(), nil, ClosePRInput{ + Repo: "test-repo", + Number: 7, + }) + require.NoError(t, err) + assert.True(t, out.Success) + assert.Equal(t, "core", out.Org) + assert.Equal(t, "test-repo", out.Repo) + assert.Equal(t, 7, out.Number) + assert.Equal(t, "closed", out.State) +} + // --- listPRs --- func TestPr_ListPRs_Bad_NoToken(t *testing.T) { s := &PrepSubsystem{ ServiceRuntime: core.NewServiceRuntime(testCore, AgentOptions{}), - forgeToken: "", - backoff: make(map[string]time.Time), - failCount: make(map[string]int), + forgeToken: "", + backoff: make(map[string]time.Time), + failCount: make(map[string]int), } _, _, err := s.listPRs(context.Background(), nil, ListPRsInput{}) @@ -252,11 +289,11 @@ func TestPr_CommentOnIssue_Good_PostsComment(t *testing.T) { s := &PrepSubsystem{ ServiceRuntime: core.NewServiceRuntime(testCore, AgentOptions{}), - forge: forge.NewForge(srv.URL, "test-token"), - forgeURL: srv.URL, - forgeToken: "test-token", - backoff: make(map[string]time.Time), - failCount: make(map[string]int), + forge: forge.NewForge(srv.URL, "test-token"), + forgeURL: srv.URL, + forgeToken: "test-token", + backoff: make(map[string]time.Time), + failCount: make(map[string]int), } s.commentOnIssue(context.Background(), "core", "go-io", 42, "Test comment") @@ -319,11 +356,11 @@ func TestPr_CommentOnIssue_Bad(t *testing.T) { s := &PrepSubsystem{ ServiceRuntime: core.NewServiceRuntime(testCore, AgentOptions{}), - forge: forge.NewForge(srv.URL, "test-token"), - forgeURL: srv.URL, - forgeToken: "test-token", - backoff: make(map[string]time.Time), - failCount: make(map[string]int), + forge: forge.NewForge(srv.URL, "test-token"), + forgeURL: srv.URL, + forgeToken: "test-token", + backoff: make(map[string]time.Time), + failCount: make(map[string]int), } // Should not panic even on server error @@ -345,11 +382,11 @@ func TestPr_CommentOnIssue_Ugly(t *testing.T) { s := &PrepSubsystem{ ServiceRuntime: core.NewServiceRuntime(testCore, AgentOptions{}), - forge: forge.NewForge(srv.URL, "test-token"), - forgeURL: srv.URL, - forgeToken: "test-token", - backoff: make(map[string]time.Time), - failCount: make(map[string]int), + forge: forge.NewForge(srv.URL, "test-token"), + forgeURL: srv.URL, + forgeToken: "test-token", + backoff: make(map[string]time.Time), + failCount: make(map[string]int), } longComment := strings.Repeat("This is a very long comment with details. ", 1000) @@ -385,9 +422,9 @@ func TestPr_CreatePR_Ugly(t *testing.T) { s := &PrepSubsystem{ ServiceRuntime: core.NewServiceRuntime(testCore, AgentOptions{}), - forgeToken: "test-token", - backoff: make(map[string]time.Time), - failCount: make(map[string]int), + forgeToken: "test-token", + backoff: make(map[string]time.Time), + failCount: make(map[string]int), } _, out, err := s.createPR(context.Background(), nil, CreatePRInput{ @@ -418,11 +455,11 @@ func TestPr_ForgeCreatePR_Ugly(t *testing.T) { s := &PrepSubsystem{ ServiceRuntime: core.NewServiceRuntime(testCore, AgentOptions{}), - forge: forge.NewForge(srv.URL, "test-token"), - forgeURL: srv.URL, - forgeToken: "test-token", - backoff: make(map[string]time.Time), - failCount: make(map[string]int), + forge: forge.NewForge(srv.URL, "test-token"), + forgeURL: srv.URL, + forgeToken: "test-token", + backoff: make(map[string]time.Time), + failCount: make(map[string]int), } // Should not panic — may return zero values for missing fields @@ -453,11 +490,11 @@ func TestPr_ListPRs_Ugly(t *testing.T) { s := &PrepSubsystem{ ServiceRuntime: core.NewServiceRuntime(testCore, AgentOptions{}), - forge: forge.NewForge(srv.URL, "test-token"), - forgeURL: srv.URL, - forgeToken: "test-token", - backoff: make(map[string]time.Time), - failCount: make(map[string]int), + forge: forge.NewForge(srv.URL, "test-token"), + forgeURL: srv.URL, + forgeToken: "test-token", + backoff: make(map[string]time.Time), + failCount: make(map[string]int), } _, out, err := s.listPRs(context.Background(), nil, ListPRsInput{ @@ -474,11 +511,11 @@ func TestPr_ListRepoPRs_Good(t *testing.T) { srv := mockPRForgeServer(t) s := &PrepSubsystem{ ServiceRuntime: core.NewServiceRuntime(testCore, AgentOptions{}), - forge: forge.NewForge(srv.URL, "test-token"), - forgeURL: srv.URL, - forgeToken: "test-token", - backoff: make(map[string]time.Time), - failCount: make(map[string]int), + forge: forge.NewForge(srv.URL, "test-token"), + forgeURL: srv.URL, + forgeToken: "test-token", + backoff: make(map[string]time.Time), + failCount: make(map[string]int), } prs, err := s.listRepoPRs(context.Background(), "core", "test-repo", "open") @@ -496,11 +533,11 @@ func TestPr_ListRepoPRs_Bad(t *testing.T) { s := &PrepSubsystem{ ServiceRuntime: core.NewServiceRuntime(testCore, AgentOptions{}), - forge: forge.NewForge(srv.URL, "test-token"), - forgeURL: srv.URL, - forgeToken: "test-token", - backoff: make(map[string]time.Time), - failCount: make(map[string]int), + forge: forge.NewForge(srv.URL, "test-token"), + forgeURL: srv.URL, + forgeToken: "test-token", + backoff: make(map[string]time.Time), + failCount: make(map[string]int), } _, err := s.listRepoPRs(context.Background(), "core", "go-io", "open") @@ -516,11 +553,11 @@ func TestPr_ListRepoPRs_Ugly(t *testing.T) { s := &PrepSubsystem{ ServiceRuntime: core.NewServiceRuntime(testCore, AgentOptions{}), - forge: forge.NewForge(srv.URL, "test-token"), - forgeURL: srv.URL, - forgeToken: "test-token", - backoff: make(map[string]time.Time), - failCount: make(map[string]int), + forge: forge.NewForge(srv.URL, "test-token"), + forgeURL: srv.URL, + forgeToken: "test-token", + backoff: make(map[string]time.Time), + failCount: make(map[string]int), } prs, err := s.listRepoPRs(context.Background(), "core", "empty-repo", "open") diff --git a/pkg/agentic/prep.go b/pkg/agentic/prep.go index 0cfda95..fba1609 100644 --- a/pkg/agentic/prep.go +++ b/pkg/agentic/prep.go @@ -348,6 +348,7 @@ func (s *PrepSubsystem) RegisterTools(server *mcp.Server) { s.registerResumeTool(server) s.registerCreatePRTool(server) s.registerListPRsTool(server) + s.registerClosePRTool(server) s.registerEpicTool(server) s.registerMirrorTool(server) s.registerRemoteDispatchTool(server)