commit 3e883f69769c4be7401819b2ae3e2122324a666c Author: Claude Date: Mon Feb 16 15:25:58 2026 +0000 feat: extract SCM/forge integration packages from core/go Forgejo and Gitea SDK wrappers, multi-repo git utilities, AgentCI dispatch, distributed job orchestrator, and data collection pipelines. Co-Authored-By: Claude Opus 4.6 diff --git a/agentci/clotho.go b/agentci/clotho.go new file mode 100644 index 0000000..2bec8ee --- /dev/null +++ b/agentci/clotho.go @@ -0,0 +1,87 @@ +package agentci + +import ( + "context" + "strings" + + "forge.lthn.ai/core/go-scm/jobrunner" +) + +// RunMode determines the execution strategy for a dispatched task. +type RunMode string + +const ( + ModeStandard RunMode = "standard" + ModeDual RunMode = "dual" // The Clotho Protocol — dual-run verification +) + +// Spinner is the Clotho orchestrator that determines the fate of each task. +type Spinner struct { + Config ClothoConfig + Agents map[string]AgentConfig +} + +// NewSpinner creates a new Clotho orchestrator. +func NewSpinner(cfg ClothoConfig, agents map[string]AgentConfig) *Spinner { + return &Spinner{ + Config: cfg, + Agents: agents, + } +} + +// DeterminePlan decides if a signal requires dual-run verification based on +// the global strategy, agent configuration, and repository criticality. +func (s *Spinner) DeterminePlan(signal *jobrunner.PipelineSignal, agentName string) RunMode { + if s.Config.Strategy != "clotho-verified" { + return ModeStandard + } + + agent, ok := s.Agents[agentName] + if !ok { + return ModeStandard + } + if agent.DualRun { + return ModeDual + } + + // Protect critical repos with dual-run (Axiom 1). + if signal.RepoName == "core" || strings.Contains(signal.RepoName, "security") { + return ModeDual + } + + return ModeStandard +} + +// GetVerifierModel returns the model for the secondary "signed" verification run. +func (s *Spinner) GetVerifierModel(agentName string) string { + agent, ok := s.Agents[agentName] + if !ok || agent.VerifyModel == "" { + return "gemini-1.5-pro" + } + return agent.VerifyModel +} + +// FindByForgejoUser resolves a Forgejo username to the agent config key and config. +// This decouples agent naming (mythological roles) from Forgejo identity. +func (s *Spinner) FindByForgejoUser(forgejoUser string) (string, AgentConfig, bool) { + if forgejoUser == "" { + return "", AgentConfig{}, false + } + // Direct match on config key first. + if agent, ok := s.Agents[forgejoUser]; ok { + return forgejoUser, agent, true + } + // Search by ForgejoUser field. + for name, agent := range s.Agents { + if agent.ForgejoUser != "" && agent.ForgejoUser == forgejoUser { + return name, agent, true + } + } + return "", AgentConfig{}, false +} + +// Weave compares primary and verifier outputs. Returns true if they converge. +// This is a placeholder for future semantic diff logic. +func (s *Spinner) Weave(ctx context.Context, primaryOutput, signedOutput []byte) (bool, error) { + return string(primaryOutput) == string(signedOutput), nil +} diff --git a/agentci/config.go b/agentci/config.go new file mode 100644 index 0000000..f2297c8 --- /dev/null +++ b/agentci/config.go @@ -0,0 +1,144 @@ +// Package agentci provides configuration, security, and orchestration for AgentCI dispatch targets. +package agentci + +import ( + "fmt" + + "forge.lthn.ai/core/go/pkg/config" +) + +// AgentConfig represents a single agent machine in the config file. 
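How the Spinner above is driven, as a hedged end-to-end sketch. The agent names and repo are illustrative, the `agentci` import path is inferred from the `jobrunner` import in clotho.go, and `PipelineSignal` is assumed to expose only the `RepoName` field that `DeterminePlan` reads:

```go
package agentci_test

import (
	"fmt"

	"forge.lthn.ai/core/go-scm/agentci"
	"forge.lthn.ai/core/go-scm/jobrunner"
)

func ExampleSpinner_DeterminePlan() {
	s := agentci.NewSpinner(
		agentci.ClothoConfig{Strategy: "clotho-verified"},
		map[string]agentci.AgentConfig{
			"atropos":  {DualRun: true, ForgejoUser: "atropos-bot"},
			"lachesis": {},
		},
	)

	// A dual-run agent is always verified.
	fmt.Println(s.DeterminePlan(&jobrunner.PipelineSignal{RepoName: "docs"}, "atropos"))
	// Critical repos force dual-run even for standard agents (Axiom 1).
	fmt.Println(s.DeterminePlan(&jobrunner.PipelineSignal{RepoName: "core"}, "lachesis"))
	// Forgejo identity resolves back to the config key.
	name, _, ok := s.FindByForgejoUser("atropos-bot")
	fmt.Println(name, ok)

	// Output:
	// dual
	// dual
	// atropos true
}
```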
+type AgentConfig struct { + Host string `yaml:"host" mapstructure:"host"` + QueueDir string `yaml:"queue_dir" mapstructure:"queue_dir"` + ForgejoUser string `yaml:"forgejo_user" mapstructure:"forgejo_user"` + Model string `yaml:"model" mapstructure:"model"` // primary AI model + Runner string `yaml:"runner" mapstructure:"runner"` // runner binary: claude, codex, gemini + VerifyModel string `yaml:"verify_model" mapstructure:"verify_model"` // secondary model for dual-run + SecurityLevel string `yaml:"security_level" mapstructure:"security_level"` // low, high + Roles []string `yaml:"roles" mapstructure:"roles"` + DualRun bool `yaml:"dual_run" mapstructure:"dual_run"` + Active bool `yaml:"active" mapstructure:"active"` +} + +// ClothoConfig controls the orchestration strategy. +type ClothoConfig struct { + Strategy string `yaml:"strategy" mapstructure:"strategy"` // direct, clotho-verified + ValidationThreshold float64 `yaml:"validation_threshold" mapstructure:"validation_threshold"` // divergence limit (0.0-1.0) + SigningKeyPath string `yaml:"signing_key_path" mapstructure:"signing_key_path"` +} + +// LoadAgents reads agent targets from config and returns a map of AgentConfig. +// Returns an empty map (not an error) if no agents are configured. +func LoadAgents(cfg *config.Config) (map[string]AgentConfig, error) { + var agents map[string]AgentConfig + if err := cfg.Get("agentci.agents", &agents); err != nil { + return map[string]AgentConfig{}, nil + } + + // Validate and apply defaults. + for name, ac := range agents { + if !ac.Active { + continue + } + if ac.Host == "" { + return nil, fmt.Errorf("agent %q: host is required", name) + } + if ac.QueueDir == "" { + ac.QueueDir = "/home/claude/ai-work/queue" + } + if ac.Model == "" { + ac.Model = "sonnet" + } + if ac.Runner == "" { + ac.Runner = "claude" + } + agents[name] = ac + } + + return agents, nil +} + +// LoadActiveAgents returns only active agents. +func LoadActiveAgents(cfg *config.Config) (map[string]AgentConfig, error) { + all, err := LoadAgents(cfg) + if err != nil { + return nil, err + } + active := make(map[string]AgentConfig) + for name, ac := range all { + if ac.Active { + active[name] = ac + } + } + return active, nil +} + +// LoadClothoConfig loads the Clotho orchestrator settings. +// Returns sensible defaults if no config is present. +func LoadClothoConfig(cfg *config.Config) (ClothoConfig, error) { + var cc ClothoConfig + if err := cfg.Get("agentci.clotho", &cc); err != nil { + return ClothoConfig{ + Strategy: "direct", + ValidationThreshold: 0.85, + }, nil + } + if cc.Strategy == "" { + cc.Strategy = "direct" + } + if cc.ValidationThreshold == 0 { + cc.ValidationThreshold = 0.85 + } + return cc, nil +} + +// SaveAgent writes an agent config entry to the config file. +func SaveAgent(cfg *config.Config, name string, ac AgentConfig) error { + key := fmt.Sprintf("agentci.agents.%s", name) + data := map[string]any{ + "host": ac.Host, + "queue_dir": ac.QueueDir, + "forgejo_user": ac.ForgejoUser, + "active": ac.Active, + "dual_run": ac.DualRun, + } + if ac.Model != "" { + data["model"] = ac.Model + } + if ac.Runner != "" { + data["runner"] = ac.Runner + } + if ac.VerifyModel != "" { + data["verify_model"] = ac.VerifyModel + } + if ac.SecurityLevel != "" { + data["security_level"] = ac.SecurityLevel + } + if len(ac.Roles) > 0 { + data["roles"] = ac.Roles + } + return cfg.Set(key, data) +} + +// RemoveAgent removes an agent from the config file. 
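At startup, the loaders above are meant to be composed into a Spinner. A minimal in-package wiring sketch using only the functions defined in this file and in clotho.go (the helper name is illustrative):

```go
// bootSpinner wires the config loaders into a Clotho Spinner. A sketch of
// the intended startup path; the caller supplies an already-open config.
func bootSpinner(cfg *config.Config) (*Spinner, error) {
	agents, err := LoadActiveAgents(cfg)
	if err != nil {
		return nil, err
	}
	cc, err := LoadClothoConfig(cfg)
	if err != nil {
		return nil, err
	}
	return NewSpinner(cc, agents), nil
}
```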
+func RemoveAgent(cfg *config.Config, name string) error { + var agents map[string]AgentConfig + if err := cfg.Get("agentci.agents", &agents); err != nil { + return fmt.Errorf("no agents configured") + } + if _, ok := agents[name]; !ok { + return fmt.Errorf("agent %q not found", name) + } + delete(agents, name) + return cfg.Set("agentci.agents", agents) +} + +// ListAgents returns all configured agents (active and inactive). +func ListAgents(cfg *config.Config) (map[string]AgentConfig, error) { + var agents map[string]AgentConfig + if err := cfg.Get("agentci.agents", &agents); err != nil { + return map[string]AgentConfig{}, nil + } + return agents, nil +} diff --git a/agentci/config_test.go b/agentci/config_test.go new file mode 100644 index 0000000..7ee40ca --- /dev/null +++ b/agentci/config_test.go @@ -0,0 +1,329 @@ +package agentci + +import ( + "testing" + + "forge.lthn.ai/core/go/pkg/config" + "forge.lthn.ai/core/go/pkg/io" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newTestConfig(t *testing.T, yaml string) *config.Config { + t.Helper() + m := io.NewMockMedium() + if yaml != "" { + m.Files["/tmp/test/config.yaml"] = yaml + } + cfg, err := config.New(config.WithMedium(m), config.WithPath("/tmp/test/config.yaml")) + require.NoError(t, err) + return cfg +} + +func TestLoadAgents_Good(t *testing.T) { + cfg := newTestConfig(t, ` +agentci: + agents: + darbs-claude: + host: claude@192.168.0.201 + queue_dir: /home/claude/ai-work/queue + forgejo_user: darbs-claude + model: sonnet + runner: claude + active: true +`) + agents, err := LoadAgents(cfg) + require.NoError(t, err) + require.Len(t, agents, 1) + + agent := agents["darbs-claude"] + assert.Equal(t, "claude@192.168.0.201", agent.Host) + assert.Equal(t, "/home/claude/ai-work/queue", agent.QueueDir) + assert.Equal(t, "sonnet", agent.Model) + assert.Equal(t, "claude", agent.Runner) +} + +func TestLoadAgents_Good_MultipleAgents(t *testing.T) { + cfg := newTestConfig(t, ` +agentci: + agents: + darbs-claude: + host: claude@192.168.0.201 + queue_dir: /home/claude/ai-work/queue + active: true + local-codex: + host: localhost + queue_dir: /home/claude/ai-work/queue + runner: codex + active: true +`) + agents, err := LoadAgents(cfg) + require.NoError(t, err) + assert.Len(t, agents, 2) + assert.Contains(t, agents, "darbs-claude") + assert.Contains(t, agents, "local-codex") +} + +func TestLoadAgents_Good_SkipsInactive(t *testing.T) { + cfg := newTestConfig(t, ` +agentci: + agents: + active-agent: + host: claude@10.0.0.1 + active: true + offline-agent: + host: claude@10.0.0.2 + active: false +`) + agents, err := LoadAgents(cfg) + require.NoError(t, err) + // Both are returned, but only active-agent has defaults applied. 
+ assert.Len(t, agents, 2) + assert.Contains(t, agents, "active-agent") +} + +func TestLoadActiveAgents_Good(t *testing.T) { + cfg := newTestConfig(t, ` +agentci: + agents: + active-agent: + host: claude@10.0.0.1 + active: true + offline-agent: + host: claude@10.0.0.2 + active: false +`) + active, err := LoadActiveAgents(cfg) + require.NoError(t, err) + assert.Len(t, active, 1) + assert.Contains(t, active, "active-agent") +} + +func TestLoadAgents_Good_Defaults(t *testing.T) { + cfg := newTestConfig(t, ` +agentci: + agents: + minimal: + host: claude@10.0.0.1 + active: true +`) + agents, err := LoadAgents(cfg) + require.NoError(t, err) + require.Len(t, agents, 1) + + agent := agents["minimal"] + assert.Equal(t, "/home/claude/ai-work/queue", agent.QueueDir) + assert.Equal(t, "sonnet", agent.Model) + assert.Equal(t, "claude", agent.Runner) +} + +func TestLoadAgents_Good_NoConfig(t *testing.T) { + cfg := newTestConfig(t, "") + agents, err := LoadAgents(cfg) + require.NoError(t, err) + assert.Empty(t, agents) +} + +func TestLoadAgents_Bad_MissingHost(t *testing.T) { + cfg := newTestConfig(t, ` +agentci: + agents: + broken: + queue_dir: /tmp + active: true +`) + _, err := LoadAgents(cfg) + assert.Error(t, err) + assert.Contains(t, err.Error(), "host is required") +} + +func TestLoadAgents_Good_WithDualRun(t *testing.T) { + cfg := newTestConfig(t, ` +agentci: + agents: + gemini-agent: + host: localhost + runner: gemini + model: gemini-2.0-flash + verify_model: gemini-1.5-pro + dual_run: true + active: true +`) + agents, err := LoadAgents(cfg) + require.NoError(t, err) + + agent := agents["gemini-agent"] + assert.Equal(t, "gemini", agent.Runner) + assert.Equal(t, "gemini-2.0-flash", agent.Model) + assert.Equal(t, "gemini-1.5-pro", agent.VerifyModel) + assert.True(t, agent.DualRun) +} + +func TestLoadClothoConfig_Good(t *testing.T) { + cfg := newTestConfig(t, ` +agentci: + clotho: + strategy: clotho-verified + validation_threshold: 0.9 + signing_key_path: /etc/core/keys/clotho.pub +`) + cc, err := LoadClothoConfig(cfg) + require.NoError(t, err) + assert.Equal(t, "clotho-verified", cc.Strategy) + assert.Equal(t, 0.9, cc.ValidationThreshold) + assert.Equal(t, "/etc/core/keys/clotho.pub", cc.SigningKeyPath) +} + +func TestLoadClothoConfig_Good_Defaults(t *testing.T) { + cfg := newTestConfig(t, "") + cc, err := LoadClothoConfig(cfg) + require.NoError(t, err) + assert.Equal(t, "direct", cc.Strategy) + assert.Equal(t, 0.85, cc.ValidationThreshold) +} + +func TestSaveAgent_Good(t *testing.T) { + cfg := newTestConfig(t, "") + + err := SaveAgent(cfg, "new-agent", AgentConfig{ + Host: "claude@10.0.0.5", + QueueDir: "/home/claude/ai-work/queue", + ForgejoUser: "new-agent", + Model: "haiku", + Runner: "claude", + Active: true, + }) + require.NoError(t, err) + + agents, err := ListAgents(cfg) + require.NoError(t, err) + require.Contains(t, agents, "new-agent") + assert.Equal(t, "claude@10.0.0.5", agents["new-agent"].Host) + assert.Equal(t, "haiku", agents["new-agent"].Model) +} + +func TestSaveAgent_Good_WithDualRun(t *testing.T) { + cfg := newTestConfig(t, "") + + err := SaveAgent(cfg, "verified-agent", AgentConfig{ + Host: "claude@10.0.0.5", + Model: "gemini-2.0-flash", + VerifyModel: "gemini-1.5-pro", + DualRun: true, + Active: true, + }) + require.NoError(t, err) + + agents, err := ListAgents(cfg) + require.NoError(t, err) + require.Contains(t, agents, "verified-agent") + assert.True(t, agents["verified-agent"].DualRun) +} + +func TestSaveAgent_Good_OmitsEmptyOptionals(t *testing.T) { + cfg := 
newTestConfig(t, "") + + err := SaveAgent(cfg, "minimal", AgentConfig{ + Host: "claude@10.0.0.1", + Active: true, + }) + require.NoError(t, err) + + agents, err := ListAgents(cfg) + require.NoError(t, err) + assert.Contains(t, agents, "minimal") +} + +func TestRemoveAgent_Good(t *testing.T) { + cfg := newTestConfig(t, ` +agentci: + agents: + to-remove: + host: claude@10.0.0.1 + active: true + to-keep: + host: claude@10.0.0.2 + active: true +`) + err := RemoveAgent(cfg, "to-remove") + require.NoError(t, err) + + agents, err := ListAgents(cfg) + require.NoError(t, err) + assert.NotContains(t, agents, "to-remove") + assert.Contains(t, agents, "to-keep") +} + +func TestRemoveAgent_Bad_NotFound(t *testing.T) { + cfg := newTestConfig(t, ` +agentci: + agents: + existing: + host: claude@10.0.0.1 + active: true +`) + err := RemoveAgent(cfg, "nonexistent") + assert.Error(t, err) + assert.Contains(t, err.Error(), "not found") +} + +func TestRemoveAgent_Bad_NoAgents(t *testing.T) { + cfg := newTestConfig(t, "") + err := RemoveAgent(cfg, "anything") + assert.Error(t, err) + assert.Contains(t, err.Error(), "no agents configured") +} + +func TestListAgents_Good(t *testing.T) { + cfg := newTestConfig(t, ` +agentci: + agents: + agent-a: + host: claude@10.0.0.1 + active: true + agent-b: + host: claude@10.0.0.2 + active: false +`) + agents, err := ListAgents(cfg) + require.NoError(t, err) + assert.Len(t, agents, 2) + assert.True(t, agents["agent-a"].Active) + assert.False(t, agents["agent-b"].Active) +} + +func TestListAgents_Good_Empty(t *testing.T) { + cfg := newTestConfig(t, "") + agents, err := ListAgents(cfg) + require.NoError(t, err) + assert.Empty(t, agents) +} + +func TestRoundTrip_SaveThenLoad(t *testing.T) { + cfg := newTestConfig(t, "") + + err := SaveAgent(cfg, "alpha", AgentConfig{ + Host: "claude@alpha", + QueueDir: "/home/claude/work/queue", + ForgejoUser: "alpha-bot", + Model: "opus", + Runner: "claude", + Active: true, + }) + require.NoError(t, err) + + err = SaveAgent(cfg, "beta", AgentConfig{ + Host: "claude@beta", + ForgejoUser: "beta-bot", + Runner: "codex", + Active: true, + }) + require.NoError(t, err) + + agents, err := LoadActiveAgents(cfg) + require.NoError(t, err) + assert.Len(t, agents, 2) + assert.Equal(t, "claude@alpha", agents["alpha"].Host) + assert.Equal(t, "opus", agents["alpha"].Model) + assert.Equal(t, "codex", agents["beta"].Runner) +} diff --git a/agentci/security.go b/agentci/security.go new file mode 100644 index 0000000..f917b3f --- /dev/null +++ b/agentci/security.go @@ -0,0 +1,49 @@ +package agentci + +import ( + "fmt" + "os/exec" + "path/filepath" + "regexp" + "strings" +) + +var safeNameRegex = regexp.MustCompile(`^[a-zA-Z0-9\-\_\.]+$`) + +// SanitizePath ensures a filename or directory name is safe and prevents path traversal. +// Returns filepath.Base of the input after validation. +func SanitizePath(input string) (string, error) { + base := filepath.Base(input) + if !safeNameRegex.MatchString(base) { + return "", fmt.Errorf("invalid characters in path element: %s", input) + } + if base == "." || base == ".." || base == "/" { + return "", fmt.Errorf("invalid path element: %s", base) + } + return base, nil +} + +// EscapeShellArg wraps a string in single quotes for safe remote shell insertion. +// Prefer exec.Command arguments over constructing shell strings where possible. 
+func EscapeShellArg(arg string) string {
+	return "'" + strings.ReplaceAll(arg, "'", "'\\''") + "'"
+}
+
+// SecureSSHCommand creates an SSH exec.Cmd with strict host key checking and batch mode.
+func SecureSSHCommand(host string, remoteCmd string) *exec.Cmd {
+	return exec.Command("ssh",
+		"-o", "StrictHostKeyChecking=yes",
+		"-o", "BatchMode=yes",
+		"-o", "ConnectTimeout=10",
+		host,
+		remoteCmd,
+	)
+}
+
+// MaskToken returns a masked version of a token for safe logging. Tokens of
+// eight characters or fewer are fully masked, since showing the first and
+// last four would reveal the entire value.
+func MaskToken(token string) string {
+	if len(token) <= 8 {
+		return "*****"
+	}
+	return token[:4] + "****" + token[len(token)-4:]
+}
diff --git a/collect/bitcointalk.go b/collect/bitcointalk.go
new file mode 100644
index 0000000..c8b3fec
--- /dev/null
+++ b/collect/bitcointalk.go
@@ -0,0 +1,297 @@
+package collect
+
+import (
+	"context"
+	"fmt"
+	"net/http"
+	"net/url"
+	"path/filepath"
+	"strings"
+	"time"
+
+	core "forge.lthn.ai/core/go/pkg/framework/core"
+	"golang.org/x/net/html"
+)
+
+// httpClient is the HTTP client used for all collection requests.
+// Use SetHTTPClient to override for testing.
+var httpClient = &http.Client{
+	Timeout: 30 * time.Second,
+}
+
+// BitcoinTalkCollector collects forum posts from BitcoinTalk.
+type BitcoinTalkCollector struct {
+	// TopicID is the numeric topic identifier.
+	TopicID string
+
+	// URL is a full URL to a BitcoinTalk topic page. If set and TopicID is
+	// empty, the topic ID is extracted from it.
+	URL string
+
+	// Pages limits collection to this many pages. 0 means all pages.
+	Pages int
+}
+
+// topicIDFromURL extracts the numeric topic ID from a BitcoinTalk topic URL,
+// e.g. "https://bitcointalk.org/index.php?topic=12345.0" yields "12345".
+func topicIDFromURL(rawURL string) string {
+	u, err := url.Parse(rawURL)
+	if err != nil {
+		return ""
+	}
+	topic := u.Query().Get("topic")
+	if i := strings.IndexByte(topic, '.'); i >= 0 {
+		topic = topic[:i]
+	}
+	return topic
+}
+
+// Name returns the collector name.
+func (b *BitcoinTalkCollector) Name() string {
+	id := b.TopicID
+	if id == "" && b.URL != "" {
+		id = "url"
+	}
+	return fmt.Sprintf("bitcointalk:%s", id)
+}
+
+// Collect gathers posts from a BitcoinTalk topic.
+func (b *BitcoinTalkCollector) Collect(ctx context.Context, cfg *Config) (*Result, error) {
+	result := &Result{Source: b.Name()}
+
+	if cfg.Dispatcher != nil {
+		cfg.Dispatcher.EmitStart(b.Name(), "Starting BitcoinTalk collection")
+	}
+
+	topicID := b.TopicID
+	if topicID == "" && b.URL != "" {
+		topicID = topicIDFromURL(b.URL)
+	}
+	if topicID == "" {
+		return result, core.E("collect.BitcoinTalk.Collect", "topic ID is required", nil)
+	}
+
+	if cfg.DryRun {
+		if cfg.Dispatcher != nil {
+			cfg.Dispatcher.EmitProgress(b.Name(), fmt.Sprintf("[dry-run] Would collect topic %s", topicID), nil)
+		}
+		return result, nil
+	}
+
+	baseDir := filepath.Join(cfg.OutputDir, "bitcointalk", topicID, "posts")
+	if err := cfg.Output.EnsureDir(baseDir); err != nil {
+		return result, core.E("collect.BitcoinTalk.Collect", "failed to create output directory", err)
+	}
+
+	postNum := 0
+	offset := 0
+	pageCount := 0
+	postsPerPage := 20
+
+	for {
+		if ctx.Err() != nil {
+			return result, core.E("collect.BitcoinTalk.Collect", "context cancelled", ctx.Err())
+		}
+
+		if b.Pages > 0 && pageCount >= b.Pages {
+			break
+		}
+
+		if cfg.Limiter != nil {
+			if err := cfg.Limiter.Wait(ctx, "bitcointalk"); err != nil {
+				return result, err
+			}
+		}
+
+		pageURL := fmt.Sprintf("https://bitcointalk.org/index.php?topic=%s.%d", topicID, offset)
+
+		posts, err := b.fetchPage(ctx, pageURL)
+		if err != nil {
+			result.Errors++
+			if cfg.Dispatcher != nil {
+				cfg.Dispatcher.EmitError(b.Name(), fmt.Sprintf("Failed to fetch page at offset %d: %v", offset, err), nil)
+			}
+			break
+		}
+
+		if len(posts) == 0 {
+			break
+		}
+
+		for _, post := range posts {
+			postNum++
+			filePath := filepath.Join(baseDir, fmt.Sprintf("%d.md", postNum))
+			content := formatPostMarkdown(postNum, post)
+
+			if err := cfg.Output.Write(filePath, content); err != nil {
+				result.Errors++
+				continue
+			}
+
result.Items++ + result.Files = append(result.Files, filePath) + + if cfg.Dispatcher != nil { + cfg.Dispatcher.EmitItem(b.Name(), fmt.Sprintf("Post %d by %s", postNum, post.Author), nil) + } + } + + pageCount++ + offset += postsPerPage + + // If we got fewer posts than expected, we've reached the end + if len(posts) < postsPerPage { + break + } + } + + if cfg.Dispatcher != nil { + cfg.Dispatcher.EmitComplete(b.Name(), fmt.Sprintf("Collected %d posts", result.Items), result) + } + + return result, nil +} + +// btPost represents a parsed BitcoinTalk forum post. +type btPost struct { + Author string + Date string + Content string +} + +// fetchPage fetches and parses a single BitcoinTalk topic page. +func (b *BitcoinTalkCollector) fetchPage(ctx context.Context, pageURL string) ([]btPost, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, pageURL, nil) + if err != nil { + return nil, core.E("collect.BitcoinTalk.fetchPage", "failed to create request", err) + } + req.Header.Set("User-Agent", "Mozilla/5.0 (compatible; CoreCollector/1.0)") + + resp, err := httpClient.Do(req) + if err != nil { + return nil, core.E("collect.BitcoinTalk.fetchPage", "request failed", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return nil, core.E("collect.BitcoinTalk.fetchPage", + fmt.Sprintf("unexpected status code: %d", resp.StatusCode), nil) + } + + doc, err := html.Parse(resp.Body) + if err != nil { + return nil, core.E("collect.BitcoinTalk.fetchPage", "failed to parse HTML", err) + } + + return extractPosts(doc), nil +} + +// extractPosts extracts post data from a parsed HTML document. +// It looks for the common BitcoinTalk post structure using div.post elements. +func extractPosts(doc *html.Node) []btPost { + var posts []btPost + var walk func(*html.Node) + + walk = func(n *html.Node) { + if n.Type == html.ElementNode && n.Data == "div" { + for _, attr := range n.Attr { + if attr.Key == "class" && strings.Contains(attr.Val, "post") { + post := parsePost(n) + if post.Content != "" { + posts = append(posts, post) + } + } + } + } + for c := n.FirstChild; c != nil; c = c.NextSibling { + walk(c) + } + } + + walk(doc) + return posts +} + +// parsePost extracts author, date, and content from a post div. +func parsePost(node *html.Node) btPost { + post := btPost{} + var walk func(*html.Node) + + walk = func(n *html.Node) { + if n.Type == html.ElementNode { + for _, attr := range n.Attr { + if attr.Key == "class" { + switch { + case strings.Contains(attr.Val, "poster_info"): + post.Author = extractText(n) + case strings.Contains(attr.Val, "headerandpost"): + // Look for date in smalltext + for c := n.FirstChild; c != nil; c = c.NextSibling { + if c.Type == html.ElementNode && c.Data == "div" { + for _, a := range c.Attr { + if a.Key == "class" && strings.Contains(a.Val, "smalltext") { + post.Date = strings.TrimSpace(extractText(c)) + } + } + } + } + case strings.Contains(attr.Val, "inner"): + post.Content = strings.TrimSpace(extractText(n)) + } + } + } + } + for c := n.FirstChild; c != nil; c = c.NextSibling { + walk(c) + } + } + + walk(node) + return post +} + +// extractText recursively extracts text content from an HTML node. 
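extractText flattens an element to plain text, inserting a newline before each block-level child (br, p, div) that follows earlier content. An in-package sketch of the behaviour, using the html and strings imports already present in this file:

```go
// exampleExtractText shows the flattening: the second <p> contributes a
// newline separator, so the result is "alpha\nbeta".
func exampleExtractText() (string, error) {
	doc, err := html.Parse(strings.NewReader(`<p>alpha</p><p>beta</p>`))
	if err != nil {
		return "", err
	}
	return extractText(doc), nil
}
```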
+func extractText(n *html.Node) string { + if n.Type == html.TextNode { + return n.Data + } + + var b strings.Builder + for c := n.FirstChild; c != nil; c = c.NextSibling { + text := extractText(c) + if text != "" { + if b.Len() > 0 && c.Type == html.ElementNode && (c.Data == "br" || c.Data == "p" || c.Data == "div") { + b.WriteString("\n") + } + b.WriteString(text) + } + } + return b.String() +} + +// formatPostMarkdown formats a BitcoinTalk post as markdown. +func formatPostMarkdown(num int, post btPost) string { + var b strings.Builder + fmt.Fprintf(&b, "# Post %d by %s\n\n", num, post.Author) + + if post.Date != "" { + fmt.Fprintf(&b, "**Date:** %s\n\n", post.Date) + } + + b.WriteString(post.Content) + b.WriteString("\n") + + return b.String() +} + +// ParsePostsFromHTML parses BitcoinTalk posts from raw HTML content. +// This is exported for testing purposes. +func ParsePostsFromHTML(htmlContent string) ([]btPost, error) { + doc, err := html.Parse(strings.NewReader(htmlContent)) + if err != nil { + return nil, core.E("collect.ParsePostsFromHTML", "failed to parse HTML", err) + } + return extractPosts(doc), nil +} + +// FormatPostMarkdown is exported for testing purposes. +func FormatPostMarkdown(num int, author, date, content string) string { + return formatPostMarkdown(num, btPost{Author: author, Date: date, Content: content}) +} + +// FetchPageFunc is an injectable function type for fetching pages, used in testing. +type FetchPageFunc func(ctx context.Context, url string) ([]btPost, error) + +// BitcoinTalkCollectorWithFetcher wraps BitcoinTalkCollector with a custom fetcher for testing. +type BitcoinTalkCollectorWithFetcher struct { + BitcoinTalkCollector + Fetcher FetchPageFunc +} + +// SetHTTPClient replaces the package-level HTTP client. +// Use this in tests to inject a custom transport or timeout. +func SetHTTPClient(c *http.Client) { + httpClient = c +} diff --git a/collect/bitcointalk_test.go b/collect/bitcointalk_test.go new file mode 100644 index 0000000..69be0a7 --- /dev/null +++ b/collect/bitcointalk_test.go @@ -0,0 +1,93 @@ +package collect + +import ( + "context" + "testing" + + "forge.lthn.ai/core/go/pkg/io" + "github.com/stretchr/testify/assert" +) + +func TestBitcoinTalkCollector_Name_Good(t *testing.T) { + b := &BitcoinTalkCollector{TopicID: "12345"} + assert.Equal(t, "bitcointalk:12345", b.Name()) +} + +func TestBitcoinTalkCollector_Name_Good_URL(t *testing.T) { + b := &BitcoinTalkCollector{URL: "https://bitcointalk.org/index.php?topic=12345.0"} + assert.Equal(t, "bitcointalk:url", b.Name()) +} + +func TestBitcoinTalkCollector_Collect_Bad_NoTopicID(t *testing.T) { + m := io.NewMockMedium() + cfg := NewConfigWithMedium(m, "/output") + + b := &BitcoinTalkCollector{} + _, err := b.Collect(context.Background(), cfg) + assert.Error(t, err) +} + +func TestBitcoinTalkCollector_Collect_Good_DryRun(t *testing.T) { + m := io.NewMockMedium() + cfg := NewConfigWithMedium(m, "/output") + cfg.DryRun = true + + b := &BitcoinTalkCollector{TopicID: "12345"} + result, err := b.Collect(context.Background(), cfg) + + assert.NoError(t, err) + assert.Equal(t, 0, result.Items) +} + +func TestParsePostsFromHTML_Good(t *testing.T) { + sampleHTML := ` + +
+
+	<html><body>
+	<div class="post">
+		<div class="poster_info">satoshi</div>
+		<div class="headerandpost">
+			<div class="smalltext">January 03, 2009</div>
+		</div>
+		<div class="inner">This is the first post content.</div>
+	</div>
+	<div class="post">
+		<div class="poster_info">hal</div>
+		<div class="headerandpost">
+			<div class="smalltext">January 10, 2009</div>
+		</div>
+		<div class="inner">Running bitcoin!</div>
+	</div>
+	</body></html>
+ ` + + posts, err := ParsePostsFromHTML(sampleHTML) + assert.NoError(t, err) + assert.Len(t, posts, 2) + + assert.Contains(t, posts[0].Author, "satoshi") + assert.Contains(t, posts[0].Content, "This is the first post content.") + assert.Contains(t, posts[0].Date, "January 03, 2009") + + assert.Contains(t, posts[1].Author, "hal") + assert.Contains(t, posts[1].Content, "Running bitcoin!") +} + +func TestParsePostsFromHTML_Good_Empty(t *testing.T) { + posts, err := ParsePostsFromHTML("") + assert.NoError(t, err) + assert.Empty(t, posts) +} + +func TestFormatPostMarkdown_Good(t *testing.T) { + md := FormatPostMarkdown(1, "satoshi", "January 03, 2009", "Hello, world!") + + assert.Contains(t, md, "# Post 1 by satoshi") + assert.Contains(t, md, "**Date:** January 03, 2009") + assert.Contains(t, md, "Hello, world!") +} + +func TestFormatPostMarkdown_Good_NoDate(t *testing.T) { + md := FormatPostMarkdown(5, "user", "", "Content here") + + assert.Contains(t, md, "# Post 5 by user") + assert.NotContains(t, md, "**Date:**") + assert.Contains(t, md, "Content here") +} diff --git a/collect/collect.go b/collect/collect.go new file mode 100644 index 0000000..12d24c6 --- /dev/null +++ b/collect/collect.go @@ -0,0 +1,103 @@ +// Package collect provides a data collection subsystem for gathering information +// from multiple sources including GitHub, BitcoinTalk, CoinGecko, and academic +// paper repositories. It supports rate limiting, incremental state tracking, +// and event-driven progress reporting. +package collect + +import ( + "context" + "path/filepath" + + "forge.lthn.ai/core/go/pkg/io" +) + +// Collector is the interface all collection sources implement. +type Collector interface { + // Name returns a human-readable name for this collector. + Name() string + + // Collect gathers data from the source and writes it to the configured output. + Collect(ctx context.Context, cfg *Config) (*Result, error) +} + +// Config holds shared configuration for all collectors. +type Config struct { + // Output is the storage medium for writing collected data. + Output io.Medium + + // OutputDir is the base directory for all collected data. + OutputDir string + + // Limiter provides per-source rate limiting. + Limiter *RateLimiter + + // State tracks collection progress for incremental runs. + State *State + + // Dispatcher manages event dispatch for progress reporting. + Dispatcher *Dispatcher + + // Verbose enables detailed logging output. + Verbose bool + + // DryRun simulates collection without writing files. + DryRun bool +} + +// Result holds the output of a collection run. +type Result struct { + // Source identifies which collector produced this result. + Source string + + // Items is the number of items successfully collected. + Items int + + // Errors is the number of errors encountered during collection. + Errors int + + // Skipped is the number of items skipped (e.g. already collected). + Skipped int + + // Files lists the paths of all files written. + Files []string +} + +// NewConfig creates a Config with sensible defaults. +// It initialises a MockMedium for output if none is provided, +// sets up a rate limiter, state tracker, and event dispatcher. +func NewConfig(outputDir string) *Config { + m := io.NewMockMedium() + return &Config{ + Output: m, + OutputDir: outputDir, + Limiter: NewRateLimiter(), + State: NewState(m, filepath.Join(outputDir, ".collect-state.json")), + Dispatcher: NewDispatcher(), + } +} + +// NewConfigWithMedium creates a Config using the specified storage medium. 
+func NewConfigWithMedium(m io.Medium, outputDir string) *Config { + return &Config{ + Output: m, + OutputDir: outputDir, + Limiter: NewRateLimiter(), + State: NewState(m, filepath.Join(outputDir, ".collect-state.json")), + Dispatcher: NewDispatcher(), + } +} + +// MergeResults combines multiple results into a single aggregated result. +func MergeResults(source string, results ...*Result) *Result { + merged := &Result{Source: source} + for _, r := range results { + if r == nil { + continue + } + merged.Items += r.Items + merged.Errors += r.Errors + merged.Skipped += r.Skipped + merged.Files = append(merged.Files, r.Files...) + } + return merged +} diff --git a/collect/collect_test.go b/collect/collect_test.go new file mode 100644 index 0000000..cc1927b --- /dev/null +++ b/collect/collect_test.go @@ -0,0 +1,68 @@ +package collect + +import ( + "testing" + + "forge.lthn.ai/core/go/pkg/io" + "github.com/stretchr/testify/assert" +) + +func TestNewConfig_Good(t *testing.T) { + cfg := NewConfig("/tmp/output") + + assert.NotNil(t, cfg) + assert.Equal(t, "/tmp/output", cfg.OutputDir) + assert.NotNil(t, cfg.Output) + assert.NotNil(t, cfg.Limiter) + assert.NotNil(t, cfg.State) + assert.NotNil(t, cfg.Dispatcher) + assert.False(t, cfg.Verbose) + assert.False(t, cfg.DryRun) +} + +func TestNewConfigWithMedium_Good(t *testing.T) { + m := io.NewMockMedium() + cfg := NewConfigWithMedium(m, "/data") + + assert.NotNil(t, cfg) + assert.Equal(t, m, cfg.Output) + assert.Equal(t, "/data", cfg.OutputDir) + assert.NotNil(t, cfg.Limiter) + assert.NotNil(t, cfg.State) + assert.NotNil(t, cfg.Dispatcher) +} + +func TestMergeResults_Good(t *testing.T) { + r1 := &Result{ + Source: "a", + Items: 5, + Errors: 1, + Files: []string{"a.md", "b.md"}, + } + r2 := &Result{ + Source: "b", + Items: 3, + Skipped: 2, + Files: []string{"c.md"}, + } + + merged := MergeResults("combined", r1, r2) + assert.Equal(t, "combined", merged.Source) + assert.Equal(t, 8, merged.Items) + assert.Equal(t, 1, merged.Errors) + assert.Equal(t, 2, merged.Skipped) + assert.Len(t, merged.Files, 3) +} + +func TestMergeResults_Good_NilResults(t *testing.T) { + r1 := &Result{Items: 3} + merged := MergeResults("test", r1, nil, nil) + assert.Equal(t, 3, merged.Items) +} + +func TestMergeResults_Good_Empty(t *testing.T) { + merged := MergeResults("empty") + assert.Equal(t, 0, merged.Items) + assert.Equal(t, 0, merged.Errors) + assert.Nil(t, merged.Files) +} diff --git a/collect/events.go b/collect/events.go new file mode 100644 index 0000000..7083986 --- /dev/null +++ b/collect/events.go @@ -0,0 +1,133 @@ +package collect + +import ( + "sync" + "time" +) + +// Event types used by the collection subsystem. +const ( + // EventStart is emitted when a collector begins its run. + EventStart = "start" + + // EventProgress is emitted to report incremental progress. + EventProgress = "progress" + + // EventItem is emitted when a single item is collected. + EventItem = "item" + + // EventError is emitted when an error occurs during collection. + EventError = "error" + + // EventComplete is emitted when a collector finishes its run. + EventComplete = "complete" +) + +// Event represents a collection event. +type Event struct { + // Type is one of the Event* constants. + Type string `json:"type"` + + // Source identifies the collector that emitted the event. + Source string `json:"source"` + + // Message is a human-readable description of the event. + Message string `json:"message"` + + // Data carries optional event-specific payload. 
+ Data any `json:"data,omitempty"` + + // Time is when the event occurred. + Time time.Time `json:"time"` +} + +// EventHandler handles collection events. +type EventHandler func(Event) + +// Dispatcher manages event dispatch. Handlers are registered per event type +// and are called synchronously when an event is emitted. +type Dispatcher struct { + mu sync.RWMutex + handlers map[string][]EventHandler +} + +// NewDispatcher creates a new event dispatcher. +func NewDispatcher() *Dispatcher { + return &Dispatcher{ + handlers: make(map[string][]EventHandler), + } +} + +// On registers a handler for an event type. Multiple handlers can be +// registered for the same event type and will be called in order. +func (d *Dispatcher) On(eventType string, handler EventHandler) { + d.mu.Lock() + defer d.mu.Unlock() + d.handlers[eventType] = append(d.handlers[eventType], handler) +} + +// Emit dispatches an event to all registered handlers for that event type. +// If no handlers are registered for the event type, the event is silently dropped. +// The event's Time field is set to now if it is zero. +func (d *Dispatcher) Emit(event Event) { + if event.Time.IsZero() { + event.Time = time.Now() + } + + d.mu.RLock() + handlers := d.handlers[event.Type] + d.mu.RUnlock() + + for _, h := range handlers { + h(event) + } +} + +// EmitStart emits a start event for the given source. +func (d *Dispatcher) EmitStart(source, message string) { + d.Emit(Event{ + Type: EventStart, + Source: source, + Message: message, + }) +} + +// EmitProgress emits a progress event. +func (d *Dispatcher) EmitProgress(source, message string, data any) { + d.Emit(Event{ + Type: EventProgress, + Source: source, + Message: message, + Data: data, + }) +} + +// EmitItem emits an item event. +func (d *Dispatcher) EmitItem(source, message string, data any) { + d.Emit(Event{ + Type: EventItem, + Source: source, + Message: message, + Data: data, + }) +} + +// EmitError emits an error event. +func (d *Dispatcher) EmitError(source, message string, data any) { + d.Emit(Event{ + Type: EventError, + Source: source, + Message: message, + Data: data, + }) +} + +// EmitComplete emits a complete event. 
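Consumers subscribe per event type. A small wiring sketch that logs every event to stdout (formatting is arbitrary; it needs "fmt" alongside this file's imports):

```go
// logEvents attaches one stdout handler per event type.
func logEvents(d *Dispatcher) {
	for _, et := range []string{EventStart, EventProgress, EventItem, EventError, EventComplete} {
		d.On(et, func(e Event) {
			fmt.Printf("%s [%s] %s: %s\n", e.Time.Format(time.RFC3339), e.Type, e.Source, e.Message)
		})
	}
}
```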
+func (d *Dispatcher) EmitComplete(source, message string, data any) { + d.Emit(Event{ + Type: EventComplete, + Source: source, + Message: message, + Data: data, + }) +} diff --git a/collect/events_test.go b/collect/events_test.go new file mode 100644 index 0000000..ae9ae5d --- /dev/null +++ b/collect/events_test.go @@ -0,0 +1,133 @@ +package collect + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestDispatcher_Emit_Good(t *testing.T) { + d := NewDispatcher() + + var received Event + d.On(EventStart, func(e Event) { + received = e + }) + + d.Emit(Event{ + Type: EventStart, + Source: "test", + Message: "hello", + }) + + assert.Equal(t, EventStart, received.Type) + assert.Equal(t, "test", received.Source) + assert.Equal(t, "hello", received.Message) + assert.False(t, received.Time.IsZero(), "Time should be set automatically") +} + +func TestDispatcher_On_Good(t *testing.T) { + d := NewDispatcher() + + var count int + handler := func(e Event) { count++ } + + d.On(EventProgress, handler) + d.On(EventProgress, handler) + d.On(EventProgress, handler) + + d.Emit(Event{Type: EventProgress, Source: "test"}) + assert.Equal(t, 3, count, "All three handlers should be called") +} + +func TestDispatcher_Emit_Good_NoHandlers(t *testing.T) { + d := NewDispatcher() + + // Should not panic when emitting an event with no handlers + assert.NotPanics(t, func() { + d.Emit(Event{ + Type: "unknown-event", + Source: "test", + Message: "this should be silently dropped", + }) + }) +} + +func TestDispatcher_Emit_Good_MultipleEventTypes(t *testing.T) { + d := NewDispatcher() + + var starts, errors int + d.On(EventStart, func(e Event) { starts++ }) + d.On(EventError, func(e Event) { errors++ }) + + d.Emit(Event{Type: EventStart, Source: "test"}) + d.Emit(Event{Type: EventStart, Source: "test"}) + d.Emit(Event{Type: EventError, Source: "test"}) + + assert.Equal(t, 2, starts) + assert.Equal(t, 1, errors) +} + +func TestDispatcher_Emit_Good_SetsTime(t *testing.T) { + d := NewDispatcher() + + var received Event + d.On(EventItem, func(e Event) { + received = e + }) + + before := time.Now() + d.Emit(Event{Type: EventItem, Source: "test"}) + after := time.Now() + + assert.True(t, received.Time.After(before) || received.Time.Equal(before)) + assert.True(t, received.Time.Before(after) || received.Time.Equal(after)) +} + +func TestDispatcher_Emit_Good_PreservesExistingTime(t *testing.T) { + d := NewDispatcher() + + customTime := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC) + var received Event + d.On(EventItem, func(e Event) { + received = e + }) + + d.Emit(Event{Type: EventItem, Source: "test", Time: customTime}) + assert.True(t, customTime.Equal(received.Time)) +} + +func TestDispatcher_EmitHelpers_Good(t *testing.T) { + d := NewDispatcher() + + events := make(map[string]Event) + for _, eventType := range []string{EventStart, EventProgress, EventItem, EventError, EventComplete} { + et := eventType + d.On(et, func(e Event) { + events[et] = e + }) + } + + d.EmitStart("s1", "started") + d.EmitProgress("s2", "progressing", map[string]int{"count": 5}) + d.EmitItem("s3", "got item", nil) + d.EmitError("s4", "something failed", nil) + d.EmitComplete("s5", "done", nil) + + assert.Equal(t, "s1", events[EventStart].Source) + assert.Equal(t, "started", events[EventStart].Message) + + assert.Equal(t, "s2", events[EventProgress].Source) + assert.NotNil(t, events[EventProgress].Data) + + assert.Equal(t, "s3", events[EventItem].Source) + assert.Equal(t, "s4", events[EventError].Source) + assert.Equal(t, 
"s5", events[EventComplete].Source) +} + +func TestNewDispatcher_Good(t *testing.T) { + d := NewDispatcher() + assert.NotNil(t, d) + assert.NotNil(t, d.handlers) +} diff --git a/collect/excavate.go b/collect/excavate.go new file mode 100644 index 0000000..b8b7136 --- /dev/null +++ b/collect/excavate.go @@ -0,0 +1,128 @@ +package collect + +import ( + "context" + "fmt" + "time" + + core "forge.lthn.ai/core/go/pkg/framework/core" +) + +// Excavator runs multiple collectors as a coordinated operation. +// It provides sequential execution with rate limit respect, state tracking +// for resume support, and aggregated results. +type Excavator struct { + // Collectors is the list of collectors to run. + Collectors []Collector + + // ScanOnly reports what would be collected without performing collection. + ScanOnly bool + + // Resume enables incremental collection using saved state. + Resume bool +} + +// Name returns the orchestrator name. +func (e *Excavator) Name() string { + return "excavator" +} + +// Run executes all collectors sequentially, respecting rate limits and +// using state for resume support. Results are aggregated from all collectors. +func (e *Excavator) Run(ctx context.Context, cfg *Config) (*Result, error) { + result := &Result{Source: e.Name()} + + if len(e.Collectors) == 0 { + return result, nil + } + + if cfg.Dispatcher != nil { + cfg.Dispatcher.EmitStart(e.Name(), fmt.Sprintf("Starting excavation with %d collectors", len(e.Collectors))) + } + + // Load state if resuming + if e.Resume && cfg.State != nil { + if err := cfg.State.Load(); err != nil { + return result, core.E("collect.Excavator.Run", "failed to load state", err) + } + } + + // If scan-only, just report what would be collected + if e.ScanOnly { + for _, c := range e.Collectors { + if cfg.Dispatcher != nil { + cfg.Dispatcher.EmitProgress(e.Name(), fmt.Sprintf("[scan] Would run collector: %s", c.Name()), nil) + } + } + return result, nil + } + + for i, c := range e.Collectors { + if ctx.Err() != nil { + return result, core.E("collect.Excavator.Run", "context cancelled", ctx.Err()) + } + + if cfg.Dispatcher != nil { + cfg.Dispatcher.EmitProgress(e.Name(), + fmt.Sprintf("Running collector %d/%d: %s", i+1, len(e.Collectors), c.Name()), nil) + } + + // Check if we should skip (already completed in a previous run) + if e.Resume && cfg.State != nil { + if entry, ok := cfg.State.Get(c.Name()); ok { + if entry.Items > 0 && !entry.LastRun.IsZero() { + if cfg.Dispatcher != nil { + cfg.Dispatcher.EmitProgress(e.Name(), + fmt.Sprintf("Skipping %s (already collected %d items on %s)", + c.Name(), entry.Items, entry.LastRun.Format(time.RFC3339)), nil) + } + result.Skipped++ + continue + } + } + } + + collectorResult, err := c.Collect(ctx, cfg) + if err != nil { + result.Errors++ + if cfg.Dispatcher != nil { + cfg.Dispatcher.EmitError(e.Name(), + fmt.Sprintf("Collector %s failed: %v", c.Name(), err), nil) + } + continue + } + + if collectorResult != nil { + result.Items += collectorResult.Items + result.Errors += collectorResult.Errors + result.Skipped += collectorResult.Skipped + result.Files = append(result.Files, collectorResult.Files...) 
+ + // Update state + if cfg.State != nil { + cfg.State.Set(c.Name(), &StateEntry{ + Source: c.Name(), + LastRun: time.Now(), + Items: collectorResult.Items, + }) + } + } + } + + // Save state + if cfg.State != nil { + if err := cfg.State.Save(); err != nil { + if cfg.Dispatcher != nil { + cfg.Dispatcher.EmitError(e.Name(), fmt.Sprintf("Failed to save state: %v", err), nil) + } + } + } + + if cfg.Dispatcher != nil { + cfg.Dispatcher.EmitComplete(e.Name(), + fmt.Sprintf("Excavation complete: %d items, %d errors, %d skipped", + result.Items, result.Errors, result.Skipped), result) + } + + return result, nil +} diff --git a/collect/excavate_test.go b/collect/excavate_test.go new file mode 100644 index 0000000..2643551 --- /dev/null +++ b/collect/excavate_test.go @@ -0,0 +1,202 @@ +package collect + +import ( + "context" + "fmt" + "testing" + + "forge.lthn.ai/core/go/pkg/io" + "github.com/stretchr/testify/assert" +) + +// mockCollector is a simple collector for testing the Excavator. +type mockCollector struct { + name string + items int + err error + called bool +} + +func (m *mockCollector) Name() string { return m.name } + +func (m *mockCollector) Collect(ctx context.Context, cfg *Config) (*Result, error) { + m.called = true + if m.err != nil { + return &Result{Source: m.name, Errors: 1}, m.err + } + + result := &Result{Source: m.name, Items: m.items} + for i := 0; i < m.items; i++ { + result.Files = append(result.Files, fmt.Sprintf("/output/%s/%d.md", m.name, i)) + } + + if cfg.DryRun { + return &Result{Source: m.name}, nil + } + + return result, nil +} + +func TestExcavator_Name_Good(t *testing.T) { + e := &Excavator{} + assert.Equal(t, "excavator", e.Name()) +} + +func TestExcavator_Run_Good(t *testing.T) { + m := io.NewMockMedium() + cfg := NewConfigWithMedium(m, "/output") + cfg.Limiter = nil + + c1 := &mockCollector{name: "source-a", items: 3} + c2 := &mockCollector{name: "source-b", items: 5} + + e := &Excavator{ + Collectors: []Collector{c1, c2}, + } + + result, err := e.Run(context.Background(), cfg) + + assert.NoError(t, err) + assert.True(t, c1.called) + assert.True(t, c2.called) + assert.Equal(t, 8, result.Items) + assert.Len(t, result.Files, 8) +} + +func TestExcavator_Run_Good_Empty(t *testing.T) { + m := io.NewMockMedium() + cfg := NewConfigWithMedium(m, "/output") + + e := &Excavator{} + result, err := e.Run(context.Background(), cfg) + + assert.NoError(t, err) + assert.Equal(t, 0, result.Items) +} + +func TestExcavator_Run_Good_DryRun(t *testing.T) { + m := io.NewMockMedium() + cfg := NewConfigWithMedium(m, "/output") + cfg.DryRun = true + + c1 := &mockCollector{name: "source-a", items: 10} + c2 := &mockCollector{name: "source-b", items: 20} + + e := &Excavator{ + Collectors: []Collector{c1, c2}, + } + + result, err := e.Run(context.Background(), cfg) + + assert.NoError(t, err) + assert.True(t, c1.called) + assert.True(t, c2.called) + // In dry run, mockCollector returns 0 items + assert.Equal(t, 0, result.Items) +} + +func TestExcavator_Run_Good_ScanOnly(t *testing.T) { + m := io.NewMockMedium() + cfg := NewConfigWithMedium(m, "/output") + + c1 := &mockCollector{name: "source-a", items: 10} + + var progressMessages []string + cfg.Dispatcher.On(EventProgress, func(e Event) { + progressMessages = append(progressMessages, e.Message) + }) + + e := &Excavator{ + Collectors: []Collector{c1}, + ScanOnly: true, + } + + result, err := e.Run(context.Background(), cfg) + + assert.NoError(t, err) + assert.False(t, c1.called, "Collector should not be called in scan-only mode") + 
assert.Equal(t, 0, result.Items) + assert.NotEmpty(t, progressMessages) + assert.Contains(t, progressMessages[0], "source-a") +} + +func TestExcavator_Run_Good_WithErrors(t *testing.T) { + m := io.NewMockMedium() + cfg := NewConfigWithMedium(m, "/output") + cfg.Limiter = nil + + c1 := &mockCollector{name: "good", items: 5} + c2 := &mockCollector{name: "bad", err: fmt.Errorf("network error")} + c3 := &mockCollector{name: "also-good", items: 3} + + e := &Excavator{ + Collectors: []Collector{c1, c2, c3}, + } + + result, err := e.Run(context.Background(), cfg) + + assert.NoError(t, err) + assert.Equal(t, 8, result.Items) + assert.Equal(t, 1, result.Errors) // c2 failed + assert.True(t, c1.called) + assert.True(t, c2.called) + assert.True(t, c3.called) +} + +func TestExcavator_Run_Good_CancelledContext(t *testing.T) { + m := io.NewMockMedium() + cfg := NewConfigWithMedium(m, "/output") + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately + + c1 := &mockCollector{name: "source-a", items: 5} + + e := &Excavator{ + Collectors: []Collector{c1}, + } + + _, err := e.Run(ctx, cfg) + assert.Error(t, err) +} + +func TestExcavator_Run_Good_SavesState(t *testing.T) { + m := io.NewMockMedium() + cfg := NewConfigWithMedium(m, "/output") + cfg.Limiter = nil + + c1 := &mockCollector{name: "source-a", items: 5} + + e := &Excavator{ + Collectors: []Collector{c1}, + } + + _, err := e.Run(context.Background(), cfg) + assert.NoError(t, err) + + // Verify state was saved + entry, ok := cfg.State.Get("source-a") + assert.True(t, ok) + assert.Equal(t, 5, entry.Items) + assert.Equal(t, "source-a", entry.Source) +} + +func TestExcavator_Run_Good_Events(t *testing.T) { + m := io.NewMockMedium() + cfg := NewConfigWithMedium(m, "/output") + cfg.Limiter = nil + + var startCount, completeCount int + cfg.Dispatcher.On(EventStart, func(e Event) { startCount++ }) + cfg.Dispatcher.On(EventComplete, func(e Event) { completeCount++ }) + + c1 := &mockCollector{name: "source-a", items: 1} + e := &Excavator{ + Collectors: []Collector{c1}, + } + + _, err := e.Run(context.Background(), cfg) + assert.NoError(t, err) + assert.Equal(t, 1, startCount) + assert.Equal(t, 1, completeCount) +} diff --git a/collect/github.go b/collect/github.go new file mode 100644 index 0000000..7a04a8f --- /dev/null +++ b/collect/github.go @@ -0,0 +1,289 @@ +package collect + +import ( + "context" + "encoding/json" + "fmt" + "os/exec" + "path/filepath" + "strings" + "time" + + core "forge.lthn.ai/core/go/pkg/framework/core" +) + +// ghIssue represents a GitHub issue or pull request as returned by the gh CLI. +type ghIssue struct { + Number int `json:"number"` + Title string `json:"title"` + State string `json:"state"` + Author ghAuthor `json:"author"` + Body string `json:"body"` + CreatedAt time.Time `json:"createdAt"` + Labels []ghLabel `json:"labels"` + URL string `json:"url"` +} + +type ghAuthor struct { + Login string `json:"login"` +} + +type ghLabel struct { + Name string `json:"name"` +} + +// ghRepo represents a GitHub repository as returned by the gh CLI. +type ghRepo struct { + Name string `json:"name"` +} + +// GitHubCollector collects issues and PRs from GitHub repositories. +type GitHubCollector struct { + // Org is the GitHub organisation. + Org string + + // Repo is the repository name. If empty and Org is set, all repos are collected. + Repo string + + // IssuesOnly limits collection to issues (excludes PRs). + IssuesOnly bool + + // PRsOnly limits collection to PRs (excludes issues). 
+ PRsOnly bool +} + +// Name returns the collector name. +func (g *GitHubCollector) Name() string { + if g.Repo != "" { + return fmt.Sprintf("github:%s/%s", g.Org, g.Repo) + } + return fmt.Sprintf("github:%s", g.Org) +} + +// Collect gathers issues and/or PRs from GitHub repositories. +func (g *GitHubCollector) Collect(ctx context.Context, cfg *Config) (*Result, error) { + result := &Result{Source: g.Name()} + + if cfg.Dispatcher != nil { + cfg.Dispatcher.EmitStart(g.Name(), "Starting GitHub collection") + } + + // If no specific repo, list all repos in the org + repos := []string{g.Repo} + if g.Repo == "" { + var err error + repos, err = g.listOrgRepos(ctx) + if err != nil { + return result, err + } + } + + for _, repo := range repos { + if ctx.Err() != nil { + return result, core.E("collect.GitHub.Collect", "context cancelled", ctx.Err()) + } + + if !g.PRsOnly { + issueResult, err := g.collectIssues(ctx, cfg, repo) + if err != nil { + result.Errors++ + if cfg.Dispatcher != nil { + cfg.Dispatcher.EmitError(g.Name(), fmt.Sprintf("Error collecting issues for %s: %v", repo, err), nil) + } + } else { + result.Items += issueResult.Items + result.Skipped += issueResult.Skipped + result.Files = append(result.Files, issueResult.Files...) + } + } + + if !g.IssuesOnly { + prResult, err := g.collectPRs(ctx, cfg, repo) + if err != nil { + result.Errors++ + if cfg.Dispatcher != nil { + cfg.Dispatcher.EmitError(g.Name(), fmt.Sprintf("Error collecting PRs for %s: %v", repo, err), nil) + } + } else { + result.Items += prResult.Items + result.Skipped += prResult.Skipped + result.Files = append(result.Files, prResult.Files...) + } + } + } + + if cfg.Dispatcher != nil { + cfg.Dispatcher.EmitComplete(g.Name(), fmt.Sprintf("Collected %d items", result.Items), result) + } + + return result, nil +} + +// listOrgRepos returns all repository names for the configured org. +func (g *GitHubCollector) listOrgRepos(ctx context.Context) ([]string, error) { + cmd := exec.CommandContext(ctx, "gh", "repo", "list", g.Org, + "--json", "name", + "--limit", "1000", + ) + out, err := cmd.Output() + if err != nil { + return nil, core.E("collect.GitHub.listOrgRepos", "failed to list repos", err) + } + + var repos []ghRepo + if err := json.Unmarshal(out, &repos); err != nil { + return nil, core.E("collect.GitHub.listOrgRepos", "failed to parse repo list", err) + } + + names := make([]string, len(repos)) + for i, r := range repos { + names[i] = r.Name + } + return names, nil +} + +// collectIssues collects issues for a single repository. 
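Collection shells out to the gh CLI, so a dry run is the cheapest way to preview work without gh or network access. An in-package sketch; the org, repo, and output path are placeholders:

```go
// previewGitHub sketches a dry run: with DryRun set and a specific repo,
// Collect emits progress events without invoking gh.
func previewGitHub(ctx context.Context) error {
	cfg := NewConfig("/tmp/collect-preview")
	cfg.DryRun = true
	g := &GitHubCollector{Org: "host-uk", Repo: "core"}
	_, err := g.Collect(ctx, cfg)
	return err
}
```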
+func (g *GitHubCollector) collectIssues(ctx context.Context, cfg *Config, repo string) (*Result, error) { + result := &Result{Source: fmt.Sprintf("github:%s/%s/issues", g.Org, repo)} + + if cfg.DryRun { + if cfg.Dispatcher != nil { + cfg.Dispatcher.EmitProgress(g.Name(), fmt.Sprintf("[dry-run] Would collect issues for %s/%s", g.Org, repo), nil) + } + return result, nil + } + + if cfg.Limiter != nil { + if err := cfg.Limiter.Wait(ctx, "github"); err != nil { + return result, err + } + } + + repoRef := fmt.Sprintf("%s/%s", g.Org, repo) + cmd := exec.CommandContext(ctx, "gh", "issue", "list", + "--repo", repoRef, + "--json", "number,title,state,author,body,createdAt,labels,url", + "--limit", "100", + "--state", "all", + ) + out, err := cmd.Output() + if err != nil { + return result, core.E("collect.GitHub.collectIssues", "gh issue list failed for "+repoRef, err) + } + + var issues []ghIssue + if err := json.Unmarshal(out, &issues); err != nil { + return result, core.E("collect.GitHub.collectIssues", "failed to parse issues", err) + } + + baseDir := filepath.Join(cfg.OutputDir, "github", g.Org, repo, "issues") + if err := cfg.Output.EnsureDir(baseDir); err != nil { + return result, core.E("collect.GitHub.collectIssues", "failed to create output directory", err) + } + + for _, issue := range issues { + filePath := filepath.Join(baseDir, fmt.Sprintf("%d.md", issue.Number)) + content := formatIssueMarkdown(issue) + + if err := cfg.Output.Write(filePath, content); err != nil { + result.Errors++ + continue + } + + result.Items++ + result.Files = append(result.Files, filePath) + + if cfg.Dispatcher != nil { + cfg.Dispatcher.EmitItem(g.Name(), fmt.Sprintf("Issue #%d: %s", issue.Number, issue.Title), nil) + } + } + + return result, nil +} + +// collectPRs collects pull requests for a single repository. 
+func (g *GitHubCollector) collectPRs(ctx context.Context, cfg *Config, repo string) (*Result, error) { + result := &Result{Source: fmt.Sprintf("github:%s/%s/pulls", g.Org, repo)} + + if cfg.DryRun { + if cfg.Dispatcher != nil { + cfg.Dispatcher.EmitProgress(g.Name(), fmt.Sprintf("[dry-run] Would collect PRs for %s/%s", g.Org, repo), nil) + } + return result, nil + } + + if cfg.Limiter != nil { + if err := cfg.Limiter.Wait(ctx, "github"); err != nil { + return result, err + } + } + + repoRef := fmt.Sprintf("%s/%s", g.Org, repo) + cmd := exec.CommandContext(ctx, "gh", "pr", "list", + "--repo", repoRef, + "--json", "number,title,state,author,body,createdAt,labels,url", + "--limit", "100", + "--state", "all", + ) + out, err := cmd.Output() + if err != nil { + return result, core.E("collect.GitHub.collectPRs", "gh pr list failed for "+repoRef, err) + } + + var prs []ghIssue + if err := json.Unmarshal(out, &prs); err != nil { + return result, core.E("collect.GitHub.collectPRs", "failed to parse pull requests", err) + } + + baseDir := filepath.Join(cfg.OutputDir, "github", g.Org, repo, "pulls") + if err := cfg.Output.EnsureDir(baseDir); err != nil { + return result, core.E("collect.GitHub.collectPRs", "failed to create output directory", err) + } + + for _, pr := range prs { + filePath := filepath.Join(baseDir, fmt.Sprintf("%d.md", pr.Number)) + content := formatIssueMarkdown(pr) + + if err := cfg.Output.Write(filePath, content); err != nil { + result.Errors++ + continue + } + + result.Items++ + result.Files = append(result.Files, filePath) + + if cfg.Dispatcher != nil { + cfg.Dispatcher.EmitItem(g.Name(), fmt.Sprintf("PR #%d: %s", pr.Number, pr.Title), nil) + } + } + + return result, nil +} + +// formatIssueMarkdown formats a GitHub issue or PR as markdown. 
+func formatIssueMarkdown(issue ghIssue) string { + var b strings.Builder + fmt.Fprintf(&b, "# %s\n\n", issue.Title) + fmt.Fprintf(&b, "- **Number:** #%d\n", issue.Number) + fmt.Fprintf(&b, "- **State:** %s\n", issue.State) + fmt.Fprintf(&b, "- **Author:** %s\n", issue.Author.Login) + fmt.Fprintf(&b, "- **Created:** %s\n", issue.CreatedAt.Format(time.RFC3339)) + + if len(issue.Labels) > 0 { + labels := make([]string, len(issue.Labels)) + for i, l := range issue.Labels { + labels[i] = l.Name + } + fmt.Fprintf(&b, "- **Labels:** %s\n", strings.Join(labels, ", ")) + } + + if issue.URL != "" { + fmt.Fprintf(&b, "- **URL:** %s\n", issue.URL) + } + + if issue.Body != "" { + fmt.Fprintf(&b, "\n%s\n", issue.Body) + } + + return b.String() +} diff --git a/collect/github_test.go b/collect/github_test.go new file mode 100644 index 0000000..a2fd1d1 --- /dev/null +++ b/collect/github_test.go @@ -0,0 +1,103 @@ +package collect + +import ( + "context" + "testing" + "time" + + "forge.lthn.ai/core/go/pkg/io" + "github.com/stretchr/testify/assert" +) + +func TestGitHubCollector_Name_Good(t *testing.T) { + g := &GitHubCollector{Org: "host-uk", Repo: "core"} + assert.Equal(t, "github:host-uk/core", g.Name()) +} + +func TestGitHubCollector_Name_Good_OrgOnly(t *testing.T) { + g := &GitHubCollector{Org: "host-uk"} + assert.Equal(t, "github:host-uk", g.Name()) +} + +func TestGitHubCollector_Collect_Good_DryRun(t *testing.T) { + m := io.NewMockMedium() + cfg := NewConfigWithMedium(m, "/output") + cfg.DryRun = true + + var progressEmitted bool + cfg.Dispatcher.On(EventProgress, func(e Event) { + progressEmitted = true + }) + + g := &GitHubCollector{Org: "host-uk", Repo: "core"} + result, err := g.Collect(context.Background(), cfg) + + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, 0, result.Items) + assert.True(t, progressEmitted, "Should emit progress event in dry-run mode") +} + +func TestGitHubCollector_Collect_Good_DryRun_IssuesOnly(t *testing.T) { + m := io.NewMockMedium() + cfg := NewConfigWithMedium(m, "/output") + cfg.DryRun = true + + g := &GitHubCollector{Org: "test-org", Repo: "test-repo", IssuesOnly: true} + result, err := g.Collect(context.Background(), cfg) + + assert.NoError(t, err) + assert.Equal(t, 0, result.Items) +} + +func TestGitHubCollector_Collect_Good_DryRun_PRsOnly(t *testing.T) { + m := io.NewMockMedium() + cfg := NewConfigWithMedium(m, "/output") + cfg.DryRun = true + + g := &GitHubCollector{Org: "test-org", Repo: "test-repo", PRsOnly: true} + result, err := g.Collect(context.Background(), cfg) + + assert.NoError(t, err) + assert.Equal(t, 0, result.Items) +} + +func TestFormatIssueMarkdown_Good(t *testing.T) { + issue := ghIssue{ + Number: 42, + Title: "Test Issue", + State: "open", + Author: ghAuthor{Login: "testuser"}, + Body: "This is the body.", + CreatedAt: time.Date(2025, 1, 15, 10, 0, 0, 0, time.UTC), + Labels: []ghLabel{ + {Name: "bug"}, + {Name: "priority"}, + }, + URL: "https://github.com/test/repo/issues/42", + } + + md := formatIssueMarkdown(issue) + + assert.Contains(t, md, "# Test Issue") + assert.Contains(t, md, "**Number:** #42") + assert.Contains(t, md, "**State:** open") + assert.Contains(t, md, "**Author:** testuser") + assert.Contains(t, md, "**Labels:** bug, priority") + assert.Contains(t, md, "This is the body.") + assert.Contains(t, md, "**URL:** https://github.com/test/repo/issues/42") +} + +func TestFormatIssueMarkdown_Good_NoLabels(t *testing.T) { + issue := ghIssue{ + Number: 1, + Title: "Simple", + State: "closed", + Author: ghAuthor{Login: 
"user"}, + } + + md := formatIssueMarkdown(issue) + + assert.Contains(t, md, "# Simple") + assert.NotContains(t, md, "**Labels:**") +} diff --git a/collect/market.go b/collect/market.go new file mode 100644 index 0000000..30312ea --- /dev/null +++ b/collect/market.go @@ -0,0 +1,277 @@ +package collect + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "path/filepath" + "strings" + "time" + + core "forge.lthn.ai/core/go/pkg/framework/core" +) + +// coinGeckoBaseURL is the base URL for the CoinGecko API. +// It is a variable so it can be overridden in tests. +var coinGeckoBaseURL = "https://api.coingecko.com/api/v3" + +// MarketCollector collects market data from CoinGecko. +type MarketCollector struct { + // CoinID is the CoinGecko coin identifier (e.g. "bitcoin", "ethereum"). + CoinID string + + // Historical enables collection of historical market chart data. + Historical bool + + // FromDate is the start date for historical data in YYYY-MM-DD format. + FromDate string +} + +// Name returns the collector name. +func (m *MarketCollector) Name() string { + return fmt.Sprintf("market:%s", m.CoinID) +} + +// coinData represents the current coin data from CoinGecko. +type coinData struct { + ID string `json:"id"` + Symbol string `json:"symbol"` + Name string `json:"name"` + MarketData marketData `json:"market_data"` +} + +type marketData struct { + CurrentPrice map[string]float64 `json:"current_price"` + MarketCap map[string]float64 `json:"market_cap"` + TotalVolume map[string]float64 `json:"total_volume"` + High24h map[string]float64 `json:"high_24h"` + Low24h map[string]float64 `json:"low_24h"` + PriceChange24h float64 `json:"price_change_24h"` + PriceChangePct24h float64 `json:"price_change_percentage_24h"` + MarketCapRank int `json:"market_cap_rank"` + TotalSupply float64 `json:"total_supply"` + CirculatingSupply float64 `json:"circulating_supply"` + LastUpdated string `json:"last_updated"` +} + +// historicalData represents historical market chart data from CoinGecko. +type historicalData struct { + Prices [][]float64 `json:"prices"` + MarketCaps [][]float64 `json:"market_caps"` + TotalVolumes [][]float64 `json:"total_volumes"` +} + +// Collect gathers market data from CoinGecko. +func (m *MarketCollector) Collect(ctx context.Context, cfg *Config) (*Result, error) { + result := &Result{Source: m.Name()} + + if m.CoinID == "" { + return result, core.E("collect.Market.Collect", "coin ID is required", nil) + } + + if cfg.Dispatcher != nil { + cfg.Dispatcher.EmitStart(m.Name(), fmt.Sprintf("Starting market data collection for %s", m.CoinID)) + } + + if cfg.DryRun { + if cfg.Dispatcher != nil { + cfg.Dispatcher.EmitProgress(m.Name(), fmt.Sprintf("[dry-run] Would collect market data for %s", m.CoinID), nil) + } + return result, nil + } + + baseDir := filepath.Join(cfg.OutputDir, "market", m.CoinID) + if err := cfg.Output.EnsureDir(baseDir); err != nil { + return result, core.E("collect.Market.Collect", "failed to create output directory", err) + } + + // Collect current data + currentResult, err := m.collectCurrent(ctx, cfg, baseDir) + if err != nil { + result.Errors++ + if cfg.Dispatcher != nil { + cfg.Dispatcher.EmitError(m.Name(), fmt.Sprintf("Failed to collect current data: %v", err), nil) + } + } else { + result.Items += currentResult.Items + result.Files = append(result.Files, currentResult.Files...) 
+ } + + // Collect historical data if requested + if m.Historical { + histResult, err := m.collectHistorical(ctx, cfg, baseDir) + if err != nil { + result.Errors++ + if cfg.Dispatcher != nil { + cfg.Dispatcher.EmitError(m.Name(), fmt.Sprintf("Failed to collect historical data: %v", err), nil) + } + } else { + result.Items += histResult.Items + result.Files = append(result.Files, histResult.Files...) + } + } + + if cfg.Dispatcher != nil { + cfg.Dispatcher.EmitComplete(m.Name(), fmt.Sprintf("Collected market data for %s", m.CoinID), result) + } + + return result, nil +} + +// collectCurrent fetches current coin data from CoinGecko. +func (m *MarketCollector) collectCurrent(ctx context.Context, cfg *Config, baseDir string) (*Result, error) { + result := &Result{Source: m.Name()} + + if cfg.Limiter != nil { + if err := cfg.Limiter.Wait(ctx, "coingecko"); err != nil { + return result, err + } + } + + url := fmt.Sprintf("%s/coins/%s", coinGeckoBaseURL, m.CoinID) + data, err := fetchJSON[coinData](ctx, url) + if err != nil { + return result, core.E("collect.Market.collectCurrent", "failed to fetch coin data", err) + } + + // Write raw JSON + jsonBytes, err := json.MarshalIndent(data, "", " ") + if err != nil { + return result, core.E("collect.Market.collectCurrent", "failed to marshal data", err) + } + + jsonPath := filepath.Join(baseDir, "current.json") + if err := cfg.Output.Write(jsonPath, string(jsonBytes)); err != nil { + return result, core.E("collect.Market.collectCurrent", "failed to write JSON", err) + } + result.Items++ + result.Files = append(result.Files, jsonPath) + + // Write summary markdown + summary := formatMarketSummary(data) + summaryPath := filepath.Join(baseDir, "summary.md") + if err := cfg.Output.Write(summaryPath, summary); err != nil { + return result, core.E("collect.Market.collectCurrent", "failed to write summary", err) + } + result.Items++ + result.Files = append(result.Files, summaryPath) + + return result, nil +} + +// collectHistorical fetches historical market chart data from CoinGecko. +func (m *MarketCollector) collectHistorical(ctx context.Context, cfg *Config, baseDir string) (*Result, error) { + result := &Result{Source: m.Name()} + + if cfg.Limiter != nil { + if err := cfg.Limiter.Wait(ctx, "coingecko"); err != nil { + return result, err + } + } + + days := "365" + if m.FromDate != "" { + fromTime, err := time.Parse("2006-01-02", m.FromDate) + if err == nil { + dayCount := int(time.Since(fromTime).Hours() / 24) + if dayCount > 0 { + days = fmt.Sprintf("%d", dayCount) + } + } + } + + url := fmt.Sprintf("%s/coins/%s/market_chart?vs_currency=usd&days=%s", coinGeckoBaseURL, m.CoinID, days) + data, err := fetchJSON[historicalData](ctx, url) + if err != nil { + return result, core.E("collect.Market.collectHistorical", "failed to fetch historical data", err) + } + + jsonBytes, err := json.MarshalIndent(data, "", " ") + if err != nil { + return result, core.E("collect.Market.collectHistorical", "failed to marshal data", err) + } + + jsonPath := filepath.Join(baseDir, "historical.json") + if err := cfg.Output.Write(jsonPath, string(jsonBytes)); err != nil { + return result, core.E("collect.Market.collectHistorical", "failed to write JSON", err) + } + result.Items++ + result.Files = append(result.Files, jsonPath) + + return result, nil +} + +// fetchJSON fetches JSON from a URL and unmarshals it into the given type. 
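+//
+// A minimal usage sketch (ctx assumed in scope; endpoint illustrative, error
+// handling elided):
+//
+//	data, err := fetchJSON[coinData](ctx, coinGeckoBaseURL+"/coins/bitcoin")
+//	if err == nil {
+//		fmt.Println(data.Name) // e.g. "Bitcoin"
+//	}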
+func fetchJSON[T any](ctx context.Context, url string) (*T, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, core.E("collect.fetchJSON", "failed to create request", err) + } + req.Header.Set("User-Agent", "CoreCollector/1.0") + req.Header.Set("Accept", "application/json") + + resp, err := httpClient.Do(req) + if err != nil { + return nil, core.E("collect.fetchJSON", "request failed", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return nil, core.E("collect.fetchJSON", + fmt.Sprintf("unexpected status code: %d for %s", resp.StatusCode, url), nil) + } + + var data T + if err := json.NewDecoder(resp.Body).Decode(&data); err != nil { + return nil, core.E("collect.fetchJSON", "failed to decode response", err) + } + + return &data, nil +} + +// formatMarketSummary formats coin data as a markdown summary. +func formatMarketSummary(data *coinData) string { + var b strings.Builder + fmt.Fprintf(&b, "# %s (%s)\n\n", data.Name, strings.ToUpper(data.Symbol)) + + md := data.MarketData + + if price, ok := md.CurrentPrice["usd"]; ok { + fmt.Fprintf(&b, "- **Current Price (USD):** $%.2f\n", price) + } + if cap, ok := md.MarketCap["usd"]; ok { + fmt.Fprintf(&b, "- **Market Cap (USD):** $%.0f\n", cap) + } + if vol, ok := md.TotalVolume["usd"]; ok { + fmt.Fprintf(&b, "- **24h Volume (USD):** $%.0f\n", vol) + } + if high, ok := md.High24h["usd"]; ok { + fmt.Fprintf(&b, "- **24h High (USD):** $%.2f\n", high) + } + if low, ok := md.Low24h["usd"]; ok { + fmt.Fprintf(&b, "- **24h Low (USD):** $%.2f\n", low) + } + + fmt.Fprintf(&b, "- **24h Price Change:** $%.2f (%.2f%%)\n", md.PriceChange24h, md.PriceChangePct24h) + + if md.MarketCapRank > 0 { + fmt.Fprintf(&b, "- **Market Cap Rank:** #%d\n", md.MarketCapRank) + } + if md.CirculatingSupply > 0 { + fmt.Fprintf(&b, "- **Circulating Supply:** %.0f\n", md.CirculatingSupply) + } + if md.TotalSupply > 0 { + fmt.Fprintf(&b, "- **Total Supply:** %.0f\n", md.TotalSupply) + } + if md.LastUpdated != "" { + fmt.Fprintf(&b, "\n*Last updated: %s*\n", md.LastUpdated) + } + + return b.String() +} + +// FormatMarketSummary is exported for testing. 
+func FormatMarketSummary(data *coinData) string { + return formatMarketSummary(data) +} diff --git a/collect/market_test.go b/collect/market_test.go new file mode 100644 index 0000000..0f4097d --- /dev/null +++ b/collect/market_test.go @@ -0,0 +1,187 @@ +package collect + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "forge.lthn.ai/core/go/pkg/io" + "github.com/stretchr/testify/assert" +) + +func TestMarketCollector_Name_Good(t *testing.T) { + m := &MarketCollector{CoinID: "bitcoin"} + assert.Equal(t, "market:bitcoin", m.Name()) +} + +func TestMarketCollector_Collect_Bad_NoCoinID(t *testing.T) { + mock := io.NewMockMedium() + cfg := NewConfigWithMedium(mock, "/output") + + m := &MarketCollector{} + _, err := m.Collect(context.Background(), cfg) + assert.Error(t, err) +} + +func TestMarketCollector_Collect_Good_DryRun(t *testing.T) { + mock := io.NewMockMedium() + cfg := NewConfigWithMedium(mock, "/output") + cfg.DryRun = true + + m := &MarketCollector{CoinID: "bitcoin"} + result, err := m.Collect(context.Background(), cfg) + + assert.NoError(t, err) + assert.Equal(t, 0, result.Items) +} + +func TestMarketCollector_Collect_Good_CurrentData(t *testing.T) { + // Set up a mock CoinGecko server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + data := coinData{ + ID: "bitcoin", + Symbol: "btc", + Name: "Bitcoin", + MarketData: marketData{ + CurrentPrice: map[string]float64{"usd": 42000.50}, + MarketCap: map[string]float64{"usd": 800000000000}, + TotalVolume: map[string]float64{"usd": 25000000000}, + High24h: map[string]float64{"usd": 43000}, + Low24h: map[string]float64{"usd": 41000}, + PriceChange24h: 500.25, + PriceChangePct24h: 1.2, + MarketCapRank: 1, + CirculatingSupply: 19500000, + TotalSupply: 21000000, + LastUpdated: "2025-01-15T10:00:00Z", + }, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(data) + })) + defer server.Close() + + // Override base URL + oldURL := coinGeckoBaseURL + coinGeckoBaseURL = server.URL + defer func() { coinGeckoBaseURL = oldURL }() + + mock := io.NewMockMedium() + cfg := NewConfigWithMedium(mock, "/output") + // Disable rate limiter to avoid delays in tests + cfg.Limiter = nil + + m := &MarketCollector{CoinID: "bitcoin"} + result, err := m.Collect(context.Background(), cfg) + + assert.NoError(t, err) + assert.Equal(t, 2, result.Items) // current.json + summary.md + assert.Len(t, result.Files, 2) + + // Verify current.json was written + content, err := mock.Read("/output/market/bitcoin/current.json") + assert.NoError(t, err) + assert.Contains(t, content, "bitcoin") + + // Verify summary.md was written + summary, err := mock.Read("/output/market/bitcoin/summary.md") + assert.NoError(t, err) + assert.Contains(t, summary, "Bitcoin") + assert.Contains(t, summary, "42000.50") +} + +func TestMarketCollector_Collect_Good_Historical(t *testing.T) { + callCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.Header().Set("Content-Type", "application/json") + + if callCount == 1 { + // Current data response + data := coinData{ + ID: "ethereum", + Symbol: "eth", + Name: "Ethereum", + MarketData: marketData{ + CurrentPrice: map[string]float64{"usd": 3000}, + }, + } + _ = json.NewEncoder(w).Encode(data) + } else { + // Historical data response + data := historicalData{ + Prices: [][]float64{{1705305600000, 3000.0}, {1705392000000, 3100.0}}, + MarketCaps: 
[][]float64{{1705305600000, 360000000000}}, + TotalVolumes: [][]float64{{1705305600000, 15000000000}}, + } + _ = json.NewEncoder(w).Encode(data) + } + })) + defer server.Close() + + oldURL := coinGeckoBaseURL + coinGeckoBaseURL = server.URL + defer func() { coinGeckoBaseURL = oldURL }() + + mock := io.NewMockMedium() + cfg := NewConfigWithMedium(mock, "/output") + cfg.Limiter = nil + + m := &MarketCollector{CoinID: "ethereum", Historical: true} + result, err := m.Collect(context.Background(), cfg) + + assert.NoError(t, err) + assert.Equal(t, 3, result.Items) // current.json + summary.md + historical.json + assert.Len(t, result.Files, 3) + + // Verify historical.json was written + content, err := mock.Read("/output/market/ethereum/historical.json") + assert.NoError(t, err) + assert.Contains(t, content, "3000") +} + +func TestFormatMarketSummary_Good(t *testing.T) { + data := &coinData{ + Name: "Bitcoin", + Symbol: "btc", + MarketData: marketData{ + CurrentPrice: map[string]float64{"usd": 50000}, + MarketCap: map[string]float64{"usd": 1000000000000}, + MarketCapRank: 1, + CirculatingSupply: 19500000, + TotalSupply: 21000000, + }, + } + + summary := FormatMarketSummary(data) + + assert.Contains(t, summary, "# Bitcoin (BTC)") + assert.Contains(t, summary, "$50000.00") + assert.Contains(t, summary, "Market Cap Rank:** #1") + assert.Contains(t, summary, "Circulating Supply") + assert.Contains(t, summary, "Total Supply") +} + +func TestMarketCollector_Collect_Bad_ServerError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + oldURL := coinGeckoBaseURL + coinGeckoBaseURL = server.URL + defer func() { coinGeckoBaseURL = oldURL }() + + mock := io.NewMockMedium() + cfg := NewConfigWithMedium(mock, "/output") + cfg.Limiter = nil + + m := &MarketCollector{CoinID: "bitcoin"} + result, err := m.Collect(context.Background(), cfg) + + // Should have errors but not fail entirely + assert.NoError(t, err) + assert.Equal(t, 1, result.Errors) +} diff --git a/collect/papers.go b/collect/papers.go new file mode 100644 index 0000000..9c2a3fc --- /dev/null +++ b/collect/papers.go @@ -0,0 +1,402 @@ +package collect + +import ( + "context" + "encoding/xml" + "fmt" + "net/http" + "net/url" + "path/filepath" + "strings" + + core "forge.lthn.ai/core/go/pkg/framework/core" + "golang.org/x/net/html" +) + +// Paper source identifiers. +const ( + PaperSourceIACR = "iacr" + PaperSourceArXiv = "arxiv" + PaperSourceAll = "all" +) + +// PapersCollector collects papers from IACR and arXiv. +type PapersCollector struct { + // Source is one of PaperSourceIACR, PaperSourceArXiv, or PaperSourceAll. + Source string + + // Category is the arXiv category (e.g. "cs.CR" for cryptography). + Category string + + // Query is the search query string. + Query string +} + +// Name returns the collector name. +func (p *PapersCollector) Name() string { + return fmt.Sprintf("papers:%s", p.Source) +} + +// paper represents a parsed academic paper. +type paper struct { + ID string + Title string + Authors []string + Abstract string + Date string + URL string + Source string +} + +// Collect gathers papers from the configured sources. 
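+//
+// A wiring sketch (ctx and cfg assumed in scope; values are illustrative, not
+// defaults):
+//
+//	p := &PapersCollector{Source: PaperSourceArXiv, Category: "cs.CR", Query: "zero-knowledge"}
+//	result, err := p.Collect(ctx, cfg)
+//	if err == nil {
+//		fmt.Printf("collected %d papers\n", result.Items)
+//	}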
+func (p *PapersCollector) Collect(ctx context.Context, cfg *Config) (*Result, error) { + result := &Result{Source: p.Name()} + + if p.Query == "" { + return result, core.E("collect.Papers.Collect", "query is required", nil) + } + + if cfg.Dispatcher != nil { + cfg.Dispatcher.EmitStart(p.Name(), fmt.Sprintf("Starting paper collection for %q", p.Query)) + } + + if cfg.DryRun { + if cfg.Dispatcher != nil { + cfg.Dispatcher.EmitProgress(p.Name(), fmt.Sprintf("[dry-run] Would search papers for %q", p.Query), nil) + } + return result, nil + } + + switch p.Source { + case PaperSourceIACR: + return p.collectIACR(ctx, cfg) + case PaperSourceArXiv: + return p.collectArXiv(ctx, cfg) + case PaperSourceAll: + iacrResult, iacrErr := p.collectIACR(ctx, cfg) + arxivResult, arxivErr := p.collectArXiv(ctx, cfg) + + if iacrErr != nil && arxivErr != nil { + return result, core.E("collect.Papers.Collect", "all sources failed", iacrErr) + } + + merged := MergeResults(p.Name(), iacrResult, arxivResult) + if iacrErr != nil { + merged.Errors++ + } + if arxivErr != nil { + merged.Errors++ + } + + if cfg.Dispatcher != nil { + cfg.Dispatcher.EmitComplete(p.Name(), fmt.Sprintf("Collected %d papers", merged.Items), merged) + } + + return merged, nil + default: + return result, core.E("collect.Papers.Collect", + fmt.Sprintf("unknown source: %s (use iacr, arxiv, or all)", p.Source), nil) + } +} + +// collectIACR fetches papers from the IACR ePrint archive. +func (p *PapersCollector) collectIACR(ctx context.Context, cfg *Config) (*Result, error) { + result := &Result{Source: "papers:iacr"} + + if cfg.Limiter != nil { + if err := cfg.Limiter.Wait(ctx, "iacr"); err != nil { + return result, err + } + } + + searchURL := fmt.Sprintf("https://eprint.iacr.org/search?q=%s", url.QueryEscape(p.Query)) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, searchURL, nil) + if err != nil { + return result, core.E("collect.Papers.collectIACR", "failed to create request", err) + } + req.Header.Set("User-Agent", "CoreCollector/1.0") + + resp, err := httpClient.Do(req) + if err != nil { + return result, core.E("collect.Papers.collectIACR", "request failed", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return result, core.E("collect.Papers.collectIACR", + fmt.Sprintf("unexpected status code: %d", resp.StatusCode), nil) + } + + doc, err := html.Parse(resp.Body) + if err != nil { + return result, core.E("collect.Papers.collectIACR", "failed to parse HTML", err) + } + + papers := extractIACRPapers(doc) + + baseDir := filepath.Join(cfg.OutputDir, "papers", "iacr") + if err := cfg.Output.EnsureDir(baseDir); err != nil { + return result, core.E("collect.Papers.collectIACR", "failed to create output directory", err) + } + + for _, ppr := range papers { + filePath := filepath.Join(baseDir, ppr.ID+".md") + content := formatPaperMarkdown(ppr) + + if err := cfg.Output.Write(filePath, content); err != nil { + result.Errors++ + continue + } + + result.Items++ + result.Files = append(result.Files, filePath) + + if cfg.Dispatcher != nil { + cfg.Dispatcher.EmitItem(p.Name(), fmt.Sprintf("Paper: %s", ppr.Title), nil) + } + } + + return result, nil +} + +// arxivFeed represents the Atom feed returned by the arXiv API. 
+type arxivFeed struct { + XMLName xml.Name `xml:"feed"` + Entries []arxivEntry `xml:"entry"` +} + +type arxivEntry struct { + ID string `xml:"id"` + Title string `xml:"title"` + Summary string `xml:"summary"` + Published string `xml:"published"` + Authors []arxivAuthor `xml:"author"` + Links []arxivLink `xml:"link"` +} + +type arxivAuthor struct { + Name string `xml:"name"` +} + +type arxivLink struct { + Href string `xml:"href,attr"` + Rel string `xml:"rel,attr"` + Type string `xml:"type,attr"` +} + +// collectArXiv fetches papers from the arXiv API. +func (p *PapersCollector) collectArXiv(ctx context.Context, cfg *Config) (*Result, error) { + result := &Result{Source: "papers:arxiv"} + + if cfg.Limiter != nil { + if err := cfg.Limiter.Wait(ctx, "arxiv"); err != nil { + return result, err + } + } + + query := url.QueryEscape(p.Query) + if p.Category != "" { + query = fmt.Sprintf("cat:%s+AND+%s", url.QueryEscape(p.Category), query) + } + + searchURL := fmt.Sprintf("https://export.arxiv.org/api/query?search_query=%s&max_results=50", query) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, searchURL, nil) + if err != nil { + return result, core.E("collect.Papers.collectArXiv", "failed to create request", err) + } + req.Header.Set("User-Agent", "CoreCollector/1.0") + + resp, err := httpClient.Do(req) + if err != nil { + return result, core.E("collect.Papers.collectArXiv", "request failed", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return result, core.E("collect.Papers.collectArXiv", + fmt.Sprintf("unexpected status code: %d", resp.StatusCode), nil) + } + + var feed arxivFeed + if err := xml.NewDecoder(resp.Body).Decode(&feed); err != nil { + return result, core.E("collect.Papers.collectArXiv", "failed to parse XML", err) + } + + baseDir := filepath.Join(cfg.OutputDir, "papers", "arxiv") + if err := cfg.Output.EnsureDir(baseDir); err != nil { + return result, core.E("collect.Papers.collectArXiv", "failed to create output directory", err) + } + + for _, entry := range feed.Entries { + ppr := arxivEntryToPaper(entry) + + filePath := filepath.Join(baseDir, ppr.ID+".md") + content := formatPaperMarkdown(ppr) + + if err := cfg.Output.Write(filePath, content); err != nil { + result.Errors++ + continue + } + + result.Items++ + result.Files = append(result.Files, filePath) + + if cfg.Dispatcher != nil { + cfg.Dispatcher.EmitItem(p.Name(), fmt.Sprintf("Paper: %s", ppr.Title), nil) + } + } + + return result, nil +} + +// arxivEntryToPaper converts an arXiv Atom entry to a paper. +func arxivEntryToPaper(entry arxivEntry) paper { + authors := make([]string, len(entry.Authors)) + for i, a := range entry.Authors { + authors[i] = a.Name + } + + // Extract the arXiv ID from the URL + id := entry.ID + if idx := strings.LastIndex(id, "/abs/"); idx != -1 { + id = id[idx+5:] + } + // Replace characters that are not valid in file names + id = strings.ReplaceAll(id, "/", "-") + id = strings.ReplaceAll(id, ":", "-") + + paperURL := entry.ID + for _, link := range entry.Links { + if link.Rel == "alternate" { + paperURL = link.Href + break + } + } + + return paper{ + ID: id, + Title: strings.TrimSpace(entry.Title), + Authors: authors, + Abstract: strings.TrimSpace(entry.Summary), + Date: entry.Published, + URL: paperURL, + Source: "arxiv", + } +} + +// extractIACRPapers extracts paper metadata from an IACR search results page. 
+func extractIACRPapers(doc *html.Node) []paper { + var papers []paper + var walk func(*html.Node) + + walk = func(n *html.Node) { + if n.Type == html.ElementNode && n.Data == "div" { + for _, attr := range n.Attr { + if attr.Key == "class" && strings.Contains(attr.Val, "paperentry") { + ppr := parseIACREntry(n) + if ppr.Title != "" { + papers = append(papers, ppr) + } + } + } + } + for c := n.FirstChild; c != nil; c = c.NextSibling { + walk(c) + } + } + + walk(doc) + return papers +} + +// parseIACREntry extracts paper data from an IACR paper entry div. +func parseIACREntry(node *html.Node) paper { + ppr := paper{Source: "iacr"} + var walk func(*html.Node) + + walk = func(n *html.Node) { + if n.Type == html.ElementNode { + switch n.Data { + case "a": + for _, attr := range n.Attr { + if attr.Key == "href" && strings.Contains(attr.Val, "/eprint/") { + ppr.URL = "https://eprint.iacr.org" + attr.Val + // Extract ID from URL + parts := strings.Split(attr.Val, "/") + if len(parts) >= 2 { + ppr.ID = parts[len(parts)-2] + "-" + parts[len(parts)-1] + } + } + } + if ppr.Title == "" { + ppr.Title = strings.TrimSpace(extractText(n)) + } + case "span": + for _, attr := range n.Attr { + if attr.Key == "class" { + switch { + case strings.Contains(attr.Val, "author"): + author := strings.TrimSpace(extractText(n)) + if author != "" { + ppr.Authors = append(ppr.Authors, author) + } + case strings.Contains(attr.Val, "date"): + ppr.Date = strings.TrimSpace(extractText(n)) + } + } + } + case "p": + for _, attr := range n.Attr { + if attr.Key == "class" && strings.Contains(attr.Val, "abstract") { + ppr.Abstract = strings.TrimSpace(extractText(n)) + } + } + } + } + for c := n.FirstChild; c != nil; c = c.NextSibling { + walk(c) + } + } + + walk(node) + return ppr +} + +// formatPaperMarkdown formats a paper as markdown. +func formatPaperMarkdown(ppr paper) string { + var b strings.Builder + fmt.Fprintf(&b, "# %s\n\n", ppr.Title) + + if len(ppr.Authors) > 0 { + fmt.Fprintf(&b, "- **Authors:** %s\n", strings.Join(ppr.Authors, ", ")) + } + if ppr.Date != "" { + fmt.Fprintf(&b, "- **Published:** %s\n", ppr.Date) + } + if ppr.URL != "" { + fmt.Fprintf(&b, "- **URL:** %s\n", ppr.URL) + } + if ppr.Source != "" { + fmt.Fprintf(&b, "- **Source:** %s\n", ppr.Source) + } + + if ppr.Abstract != "" { + fmt.Fprintf(&b, "\n## Abstract\n\n%s\n", ppr.Abstract) + } + + return b.String() +} + +// FormatPaperMarkdown is exported for testing. 
+func FormatPaperMarkdown(title string, authors []string, date, paperURL, source, abstract string) string { + return formatPaperMarkdown(paper{ + Title: title, + Authors: authors, + Date: date, + URL: paperURL, + Source: source, + Abstract: abstract, + }) +} diff --git a/collect/papers_test.go b/collect/papers_test.go new file mode 100644 index 0000000..7a89e92 --- /dev/null +++ b/collect/papers_test.go @@ -0,0 +1,108 @@ +package collect + +import ( + "context" + "testing" + + "forge.lthn.ai/core/go/pkg/io" + "github.com/stretchr/testify/assert" +) + +func TestPapersCollector_Name_Good(t *testing.T) { + p := &PapersCollector{Source: PaperSourceIACR} + assert.Equal(t, "papers:iacr", p.Name()) +} + +func TestPapersCollector_Name_Good_ArXiv(t *testing.T) { + p := &PapersCollector{Source: PaperSourceArXiv} + assert.Equal(t, "papers:arxiv", p.Name()) +} + +func TestPapersCollector_Name_Good_All(t *testing.T) { + p := &PapersCollector{Source: PaperSourceAll} + assert.Equal(t, "papers:all", p.Name()) +} + +func TestPapersCollector_Collect_Bad_NoQuery(t *testing.T) { + m := io.NewMockMedium() + cfg := NewConfigWithMedium(m, "/output") + + p := &PapersCollector{Source: PaperSourceIACR} + _, err := p.Collect(context.Background(), cfg) + assert.Error(t, err) +} + +func TestPapersCollector_Collect_Bad_UnknownSource(t *testing.T) { + m := io.NewMockMedium() + cfg := NewConfigWithMedium(m, "/output") + + p := &PapersCollector{Source: "unknown", Query: "test"} + _, err := p.Collect(context.Background(), cfg) + assert.Error(t, err) +} + +func TestPapersCollector_Collect_Good_DryRun(t *testing.T) { + m := io.NewMockMedium() + cfg := NewConfigWithMedium(m, "/output") + cfg.DryRun = true + + p := &PapersCollector{Source: PaperSourceAll, Query: "cryptography"} + result, err := p.Collect(context.Background(), cfg) + + assert.NoError(t, err) + assert.Equal(t, 0, result.Items) +} + +func TestFormatPaperMarkdown_Good(t *testing.T) { + md := FormatPaperMarkdown( + "Zero-Knowledge Proofs Revisited", + []string{"Alice", "Bob"}, + "2025-01-15", + "https://eprint.iacr.org/2025/001", + "iacr", + "We present a new construction for zero-knowledge proofs.", + ) + + assert.Contains(t, md, "# Zero-Knowledge Proofs Revisited") + assert.Contains(t, md, "**Authors:** Alice, Bob") + assert.Contains(t, md, "**Published:** 2025-01-15") + assert.Contains(t, md, "**URL:** https://eprint.iacr.org/2025/001") + assert.Contains(t, md, "**Source:** iacr") + assert.Contains(t, md, "## Abstract") + assert.Contains(t, md, "zero-knowledge proofs") +} + +func TestFormatPaperMarkdown_Good_Minimal(t *testing.T) { + md := FormatPaperMarkdown("Title Only", nil, "", "", "", "") + + assert.Contains(t, md, "# Title Only") + assert.NotContains(t, md, "**Authors:**") + assert.NotContains(t, md, "## Abstract") +} + +func TestArxivEntryToPaper_Good(t *testing.T) { + entry := arxivEntry{ + ID: "http://arxiv.org/abs/2501.12345v1", + Title: " A Great Paper ", + Summary: " This paper presents... 
", + Published: "2025-01-15T00:00:00Z", + Authors: []arxivAuthor{ + {Name: "Alice"}, + {Name: "Bob"}, + }, + Links: []arxivLink{ + {Href: "http://arxiv.org/abs/2501.12345v1", Rel: "alternate"}, + {Href: "http://arxiv.org/pdf/2501.12345v1", Rel: "related", Type: "application/pdf"}, + }, + } + + ppr := arxivEntryToPaper(entry) + + assert.Equal(t, "2501.12345v1", ppr.ID) + assert.Equal(t, "A Great Paper", ppr.Title) + assert.Equal(t, "This paper presents...", ppr.Abstract) + assert.Equal(t, "2025-01-15T00:00:00Z", ppr.Date) + assert.Equal(t, []string{"Alice", "Bob"}, ppr.Authors) + assert.Equal(t, "http://arxiv.org/abs/2501.12345v1", ppr.URL) + assert.Equal(t, "arxiv", ppr.Source) +} diff --git a/collect/process.go b/collect/process.go new file mode 100644 index 0000000..b907bd9 --- /dev/null +++ b/collect/process.go @@ -0,0 +1,345 @@ +package collect + +import ( + "context" + "encoding/json" + "fmt" + "path/filepath" + "sort" + "strings" + + core "forge.lthn.ai/core/go/pkg/framework/core" + "golang.org/x/net/html" +) + +// Processor converts collected data to clean markdown. +type Processor struct { + // Source identifies the data source directory to process. + Source string + + // Dir is the directory containing files to process. + Dir string +} + +// Name returns the processor name. +func (p *Processor) Name() string { + return fmt.Sprintf("process:%s", p.Source) +} + +// Process reads files from the source directory, converts HTML or JSON +// to clean markdown, and writes the results to the output directory. +func (p *Processor) Process(ctx context.Context, cfg *Config) (*Result, error) { + result := &Result{Source: p.Name()} + + if p.Dir == "" { + return result, core.E("collect.Processor.Process", "directory is required", nil) + } + + if cfg.Dispatcher != nil { + cfg.Dispatcher.EmitStart(p.Name(), fmt.Sprintf("Processing files in %s", p.Dir)) + } + + if cfg.DryRun { + if cfg.Dispatcher != nil { + cfg.Dispatcher.EmitProgress(p.Name(), fmt.Sprintf("[dry-run] Would process files in %s", p.Dir), nil) + } + return result, nil + } + + entries, err := cfg.Output.List(p.Dir) + if err != nil { + return result, core.E("collect.Processor.Process", "failed to list directory", err) + } + + outputDir := filepath.Join(cfg.OutputDir, "processed", p.Source) + if err := cfg.Output.EnsureDir(outputDir); err != nil { + return result, core.E("collect.Processor.Process", "failed to create output directory", err) + } + + for _, entry := range entries { + if ctx.Err() != nil { + return result, core.E("collect.Processor.Process", "context cancelled", ctx.Err()) + } + + if entry.IsDir() { + continue + } + + name := entry.Name() + srcPath := filepath.Join(p.Dir, name) + + content, err := cfg.Output.Read(srcPath) + if err != nil { + result.Errors++ + continue + } + + var processed string + ext := strings.ToLower(filepath.Ext(name)) + + switch ext { + case ".html", ".htm": + processed, err = htmlToMarkdown(content) + if err != nil { + result.Errors++ + if cfg.Dispatcher != nil { + cfg.Dispatcher.EmitError(p.Name(), fmt.Sprintf("Failed to convert %s: %v", name, err), nil) + } + continue + } + case ".json": + processed, err = jsonToMarkdown(content) + if err != nil { + result.Errors++ + if cfg.Dispatcher != nil { + cfg.Dispatcher.EmitError(p.Name(), fmt.Sprintf("Failed to convert %s: %v", name, err), nil) + } + continue + } + case ".md": + // Already markdown, just clean up + processed = strings.TrimSpace(content) + default: + result.Skipped++ + continue + } + + // Write with .md extension + outName := 
strings.TrimSuffix(name, ext) + ".md" + outPath := filepath.Join(outputDir, outName) + + if err := cfg.Output.Write(outPath, processed); err != nil { + result.Errors++ + continue + } + + result.Items++ + result.Files = append(result.Files, outPath) + + if cfg.Dispatcher != nil { + cfg.Dispatcher.EmitItem(p.Name(), fmt.Sprintf("Processed: %s", name), nil) + } + } + + if cfg.Dispatcher != nil { + cfg.Dispatcher.EmitComplete(p.Name(), fmt.Sprintf("Processed %d files", result.Items), result) + } + + return result, nil +} + +// htmlToMarkdown converts HTML content to clean markdown. +func htmlToMarkdown(content string) (string, error) { + doc, err := html.Parse(strings.NewReader(content)) + if err != nil { + return "", core.E("collect.htmlToMarkdown", "failed to parse HTML", err) + } + + var b strings.Builder + nodeToMarkdown(&b, doc, 0) + return strings.TrimSpace(b.String()), nil +} + +// nodeToMarkdown recursively converts an HTML node tree to markdown. +func nodeToMarkdown(b *strings.Builder, n *html.Node, depth int) { + switch n.Type { + case html.TextNode: + text := n.Data + if strings.TrimSpace(text) != "" { + b.WriteString(text) + } + case html.ElementNode: + switch n.Data { + case "h1": + b.WriteString("\n# ") + writeChildrenText(b, n) + b.WriteString("\n\n") + return + case "h2": + b.WriteString("\n## ") + writeChildrenText(b, n) + b.WriteString("\n\n") + return + case "h3": + b.WriteString("\n### ") + writeChildrenText(b, n) + b.WriteString("\n\n") + return + case "h4": + b.WriteString("\n#### ") + writeChildrenText(b, n) + b.WriteString("\n\n") + return + case "h5": + b.WriteString("\n##### ") + writeChildrenText(b, n) + b.WriteString("\n\n") + return + case "h6": + b.WriteString("\n###### ") + writeChildrenText(b, n) + b.WriteString("\n\n") + return + case "p": + b.WriteString("\n") + for c := n.FirstChild; c != nil; c = c.NextSibling { + nodeToMarkdown(b, c, depth) + } + b.WriteString("\n") + return + case "br": + b.WriteString("\n") + return + case "strong", "b": + b.WriteString("**") + writeChildrenText(b, n) + b.WriteString("**") + return + case "em", "i": + b.WriteString("*") + writeChildrenText(b, n) + b.WriteString("*") + return + case "code": + b.WriteString("`") + writeChildrenText(b, n) + b.WriteString("`") + return + case "pre": + b.WriteString("\n```\n") + writeChildrenText(b, n) + b.WriteString("\n```\n") + return + case "a": + var href string + for _, attr := range n.Attr { + if attr.Key == "href" { + href = attr.Val + } + } + text := getChildrenText(n) + if href != "" { + fmt.Fprintf(b, "[%s](%s)", text, href) + } else { + b.WriteString(text) + } + return + case "ul": + b.WriteString("\n") + case "ol": + b.WriteString("\n") + counter := 1 + for c := n.FirstChild; c != nil; c = c.NextSibling { + if c.Type == html.ElementNode && c.Data == "li" { + fmt.Fprintf(b, "%d. 
", counter) + for gc := c.FirstChild; gc != nil; gc = gc.NextSibling { + nodeToMarkdown(b, gc, depth+1) + } + b.WriteString("\n") + counter++ + } + } + return + case "li": + b.WriteString("- ") + for c := n.FirstChild; c != nil; c = c.NextSibling { + nodeToMarkdown(b, c, depth+1) + } + b.WriteString("\n") + return + case "blockquote": + b.WriteString("\n> ") + text := getChildrenText(n) + b.WriteString(strings.ReplaceAll(text, "\n", "\n> ")) + b.WriteString("\n") + return + case "hr": + b.WriteString("\n---\n") + return + case "script", "style", "head": + return + } + } + + for c := n.FirstChild; c != nil; c = c.NextSibling { + nodeToMarkdown(b, c, depth) + } +} + +// writeChildrenText writes the text content of all children. +func writeChildrenText(b *strings.Builder, n *html.Node) { + b.WriteString(getChildrenText(n)) +} + +// getChildrenText returns the concatenated text content of all children. +func getChildrenText(n *html.Node) string { + var b strings.Builder + for c := n.FirstChild; c != nil; c = c.NextSibling { + if c.Type == html.TextNode { + b.WriteString(c.Data) + } else { + b.WriteString(getChildrenText(c)) + } + } + return b.String() +} + +// jsonToMarkdown converts JSON content to a formatted markdown document. +func jsonToMarkdown(content string) (string, error) { + var data any + if err := json.Unmarshal([]byte(content), &data); err != nil { + return "", core.E("collect.jsonToMarkdown", "failed to parse JSON", err) + } + + var b strings.Builder + b.WriteString("# Data\n\n") + jsonValueToMarkdown(&b, data, 0) + return strings.TrimSpace(b.String()), nil +} + +// jsonValueToMarkdown recursively formats a JSON value as markdown. +func jsonValueToMarkdown(b *strings.Builder, data any, depth int) { + switch v := data.(type) { + case map[string]any: + keys := make([]string, 0, len(v)) + for key := range v { + keys = append(keys, key) + } + sort.Strings(keys) + for _, key := range keys { + val := v[key] + indent := strings.Repeat(" ", depth) + switch child := val.(type) { + case map[string]any, []any: + fmt.Fprintf(b, "%s- **%s:**\n", indent, key) + jsonValueToMarkdown(b, child, depth+1) + default: + fmt.Fprintf(b, "%s- **%s:** %v\n", indent, key, val) + } + } + case []any: + for i, item := range v { + indent := strings.Repeat(" ", depth) + switch child := item.(type) { + case map[string]any, []any: + fmt.Fprintf(b, "%s- Item %d:\n", indent, i+1) + jsonValueToMarkdown(b, child, depth+1) + default: + fmt.Fprintf(b, "%s- %v\n", indent, item) + } + } + default: + indent := strings.Repeat(" ", depth) + fmt.Fprintf(b, "%s%v\n", indent, data) + } +} + +// HTMLToMarkdown is exported for testing. +func HTMLToMarkdown(content string) (string, error) { + return htmlToMarkdown(content) +} + +// JSONToMarkdown is exported for testing. 
+func JSONToMarkdown(content string) (string, error) { + return jsonToMarkdown(content) +} diff --git a/collect/process_test.go b/collect/process_test.go new file mode 100644 index 0000000..7b0b887 --- /dev/null +++ b/collect/process_test.go @@ -0,0 +1,201 @@ +package collect + +import ( + "context" + "testing" + + "forge.lthn.ai/core/go/pkg/io" + "github.com/stretchr/testify/assert" +) + +func TestProcessor_Name_Good(t *testing.T) { + p := &Processor{Source: "github"} + assert.Equal(t, "process:github", p.Name()) +} + +func TestProcessor_Process_Bad_NoDir(t *testing.T) { + m := io.NewMockMedium() + cfg := NewConfigWithMedium(m, "/output") + + p := &Processor{Source: "test"} + _, err := p.Process(context.Background(), cfg) + assert.Error(t, err) +} + +func TestProcessor_Process_Good_DryRun(t *testing.T) { + m := io.NewMockMedium() + cfg := NewConfigWithMedium(m, "/output") + cfg.DryRun = true + + p := &Processor{Source: "test", Dir: "/input"} + result, err := p.Process(context.Background(), cfg) + + assert.NoError(t, err) + assert.Equal(t, 0, result.Items) +} + +func TestProcessor_Process_Good_HTMLFiles(t *testing.T) { + m := io.NewMockMedium() + m.Dirs["/input"] = true + m.Files["/input/page.html"] = `
<html><body><h1>Hello</h1><p>World</p></body></html>
` + + cfg := NewConfigWithMedium(m, "/output") + cfg.Limiter = nil + + p := &Processor{Source: "test", Dir: "/input"} + result, err := p.Process(context.Background(), cfg) + + assert.NoError(t, err) + assert.Equal(t, 1, result.Items) + assert.Len(t, result.Files, 1) + + content, err := m.Read("/output/processed/test/page.md") + assert.NoError(t, err) + assert.Contains(t, content, "# Hello") + assert.Contains(t, content, "World") +} + +func TestProcessor_Process_Good_JSONFiles(t *testing.T) { + m := io.NewMockMedium() + m.Dirs["/input"] = true + m.Files["/input/data.json"] = `{"name": "Bitcoin", "price": 42000}` + + cfg := NewConfigWithMedium(m, "/output") + cfg.Limiter = nil + + p := &Processor{Source: "market", Dir: "/input"} + result, err := p.Process(context.Background(), cfg) + + assert.NoError(t, err) + assert.Equal(t, 1, result.Items) + + content, err := m.Read("/output/processed/market/data.md") + assert.NoError(t, err) + assert.Contains(t, content, "# Data") + assert.Contains(t, content, "Bitcoin") +} + +func TestProcessor_Process_Good_MarkdownPassthrough(t *testing.T) { + m := io.NewMockMedium() + m.Dirs["/input"] = true + m.Files["/input/readme.md"] = "# Already Markdown\n\nThis is already formatted." + + cfg := NewConfigWithMedium(m, "/output") + cfg.Limiter = nil + + p := &Processor{Source: "docs", Dir: "/input"} + result, err := p.Process(context.Background(), cfg) + + assert.NoError(t, err) + assert.Equal(t, 1, result.Items) + + content, err := m.Read("/output/processed/docs/readme.md") + assert.NoError(t, err) + assert.Contains(t, content, "# Already Markdown") +} + +func TestProcessor_Process_Good_SkipUnknownTypes(t *testing.T) { + m := io.NewMockMedium() + m.Dirs["/input"] = true + m.Files["/input/image.png"] = "binary data" + m.Files["/input/doc.html"] = "
<h1>Heading</h1>
" + + cfg := NewConfigWithMedium(m, "/output") + cfg.Limiter = nil + + p := &Processor{Source: "mixed", Dir: "/input"} + result, err := p.Process(context.Background(), cfg) + + assert.NoError(t, err) + assert.Equal(t, 1, result.Items) // Only the HTML file + assert.Equal(t, 1, result.Skipped) // The PNG file +} + +func TestHTMLToMarkdown_Good(t *testing.T) { + tests := []struct { + name string + input string + contains []string + }{ + { + name: "heading", + input: "
<h1>Title</h1>
", + contains: []string{"# Title"}, + }, + { + name: "paragraph", + input: "
<p>Hello world</p>
", + contains: []string{"Hello world"}, + }, + { + name: "bold", + input: "
<b>bold text</b>
", + contains: []string{"**bold text**"}, + }, + { + name: "italic", + input: "
<i>italic text</i>
", + contains: []string{"*italic text*"}, + }, + { + name: "code", + input: "
<code>code</code>
", + contains: []string{"`code`"}, + }, + { + name: "link", + input: `
<a href="https://example.com">Example</a>
`, + contains: []string{"[Example](https://example.com)"}, + }, + { + name: "nested headings", + input: "
<h2>Section</h2><h3>Subsection</h3>
", + contains: []string{"## Section", "### Subsection"}, + }, + { + name: "pre block", + input: "
<pre>func main() {}</pre>
", + contains: []string{"```", "func main() {}"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := HTMLToMarkdown(tt.input) + assert.NoError(t, err) + for _, s := range tt.contains { + assert.Contains(t, result, s) + } + }) + } +} + +func TestHTMLToMarkdown_Good_StripsScripts(t *testing.T) { + input := `
<html><head><script>alert("x")</script></head><body><p>Clean</p></body></html>
` + result, err := HTMLToMarkdown(input) + assert.NoError(t, err) + assert.Contains(t, result, "Clean") + assert.NotContains(t, result, "alert") + assert.NotContains(t, result, "script") +} + +func TestJSONToMarkdown_Good(t *testing.T) { + input := `{"name": "test", "count": 42}` + result, err := JSONToMarkdown(input) + assert.NoError(t, err) + assert.Contains(t, result, "# Data") + assert.Contains(t, result, "test") + assert.Contains(t, result, "42") +} + +func TestJSONToMarkdown_Good_Array(t *testing.T) { + input := `[{"id": 1}, {"id": 2}]` + result, err := JSONToMarkdown(input) + assert.NoError(t, err) + assert.Contains(t, result, "# Data") +} + +func TestJSONToMarkdown_Bad_InvalidJSON(t *testing.T) { + _, err := JSONToMarkdown("not json") + assert.Error(t, err) +} diff --git a/collect/ratelimit.go b/collect/ratelimit.go new file mode 100644 index 0000000..469d493 --- /dev/null +++ b/collect/ratelimit.go @@ -0,0 +1,130 @@ +package collect + +import ( + "context" + "fmt" + "os/exec" + "strconv" + "strings" + "sync" + "time" + + core "forge.lthn.ai/core/go/pkg/framework/core" +) + +// RateLimiter tracks per-source rate limiting to avoid overwhelming APIs. +type RateLimiter struct { + mu sync.Mutex + delays map[string]time.Duration + last map[string]time.Time +} + +// Default rate limit delays per source. +var defaultDelays = map[string]time.Duration{ + "github": 500 * time.Millisecond, + "bitcointalk": 2 * time.Second, + "coingecko": 1500 * time.Millisecond, + "iacr": 1 * time.Second, + "arxiv": 1 * time.Second, +} + +// NewRateLimiter creates a limiter with default delays. +func NewRateLimiter() *RateLimiter { + delays := make(map[string]time.Duration, len(defaultDelays)) + for k, v := range defaultDelays { + delays[k] = v + } + return &RateLimiter{ + delays: delays, + last: make(map[string]time.Time), + } +} + +// Wait blocks until the rate limit allows the next request for the given source. +// It respects context cancellation. +func (r *RateLimiter) Wait(ctx context.Context, source string) error { + r.mu.Lock() + delay, ok := r.delays[source] + if !ok { + delay = 500 * time.Millisecond + } + lastTime := r.last[source] + + elapsed := time.Since(lastTime) + if elapsed >= delay { + // Enough time has passed — claim the slot immediately. + r.last[source] = time.Now() + r.mu.Unlock() + return nil + } + + remaining := delay - elapsed + r.mu.Unlock() + + // Wait outside the lock, then reclaim. + select { + case <-ctx.Done(): + return core.E("collect.RateLimiter.Wait", "context cancelled", ctx.Err()) + case <-time.After(remaining): + } + + r.mu.Lock() + r.last[source] = time.Now() + r.mu.Unlock() + + return nil +} + +// SetDelay sets the delay for a source. +func (r *RateLimiter) SetDelay(source string, d time.Duration) { + r.mu.Lock() + defer r.mu.Unlock() + r.delays[source] = d +} + +// GetDelay returns the delay configured for a source. +func (r *RateLimiter) GetDelay(source string) time.Duration { + r.mu.Lock() + defer r.mu.Unlock() + if d, ok := r.delays[source]; ok { + return d + } + return 500 * time.Millisecond +} + +// CheckGitHubRateLimit checks GitHub API rate limit status via gh api. +// Returns used and limit counts. Auto-pauses at 75% usage by increasing +// the GitHub rate limit delay. 
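+//
+// Requires an authenticated gh CLI on PATH. Polling sketch (rl is a
+// *RateLimiter):
+//
+//	used, limit, err := rl.CheckGitHubRateLimit()
+//	if err == nil {
+//		fmt.Printf("GitHub API usage: %d/%d\n", used, limit)
+//	}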
+func (r *RateLimiter) CheckGitHubRateLimit() (used, limit int, err error) { + cmd := exec.Command("gh", "api", "rate_limit", "--jq", ".rate | \"\\(.used) \\(.limit)\"") + out, err := cmd.Output() + if err != nil { + return 0, 0, core.E("collect.RateLimiter.CheckGitHubRateLimit", "failed to check rate limit", err) + } + + parts := strings.Fields(strings.TrimSpace(string(out))) + if len(parts) != 2 { + return 0, 0, core.E("collect.RateLimiter.CheckGitHubRateLimit", + fmt.Sprintf("unexpected output format: %q", string(out)), nil) + } + + used, err = strconv.Atoi(parts[0]) + if err != nil { + return 0, 0, core.E("collect.RateLimiter.CheckGitHubRateLimit", "failed to parse used count", err) + } + + limit, err = strconv.Atoi(parts[1]) + if err != nil { + return 0, 0, core.E("collect.RateLimiter.CheckGitHubRateLimit", "failed to parse limit count", err) + } + + // Auto-pause at 75% usage + if limit > 0 { + usage := float64(used) / float64(limit) + if usage >= 0.75 { + r.SetDelay("github", 5*time.Second) + } + } + + return used, limit, nil +} diff --git a/collect/ratelimit_test.go b/collect/ratelimit_test.go new file mode 100644 index 0000000..778d36d --- /dev/null +++ b/collect/ratelimit_test.go @@ -0,0 +1,84 @@ +package collect + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestRateLimiter_Wait_Good(t *testing.T) { + rl := NewRateLimiter() + rl.SetDelay("test", 50*time.Millisecond) + + ctx := context.Background() + + // First call should return immediately + start := time.Now() + err := rl.Wait(ctx, "test") + assert.NoError(t, err) + assert.Less(t, time.Since(start), 50*time.Millisecond) + + // Second call should wait at least the delay + start = time.Now() + err = rl.Wait(ctx, "test") + assert.NoError(t, err) + assert.GreaterOrEqual(t, time.Since(start), 40*time.Millisecond) // allow small timing variance +} + +func TestRateLimiter_Wait_Bad_ContextCancelled(t *testing.T) { + rl := NewRateLimiter() + rl.SetDelay("test", 5*time.Second) + + ctx := context.Background() + + // First call to set the last time + err := rl.Wait(ctx, "test") + assert.NoError(t, err) + + // Cancel context before second call + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = rl.Wait(ctx, "test") + assert.Error(t, err) +} + +func TestRateLimiter_SetDelay_Good(t *testing.T) { + rl := NewRateLimiter() + rl.SetDelay("custom", 3*time.Second) + assert.Equal(t, 3*time.Second, rl.GetDelay("custom")) +} + +func TestRateLimiter_GetDelay_Good_Defaults(t *testing.T) { + rl := NewRateLimiter() + + assert.Equal(t, 500*time.Millisecond, rl.GetDelay("github")) + assert.Equal(t, 2*time.Second, rl.GetDelay("bitcointalk")) + assert.Equal(t, 1500*time.Millisecond, rl.GetDelay("coingecko")) + assert.Equal(t, 1*time.Second, rl.GetDelay("iacr")) +} + +func TestRateLimiter_GetDelay_Good_UnknownSource(t *testing.T) { + rl := NewRateLimiter() + // Unknown sources should get the default 500ms delay + assert.Equal(t, 500*time.Millisecond, rl.GetDelay("unknown")) +} + +func TestRateLimiter_Wait_Good_UnknownSource(t *testing.T) { + rl := NewRateLimiter() + ctx := context.Background() + + // Unknown source should use default delay of 500ms + err := rl.Wait(ctx, "unknown-source") + assert.NoError(t, err) +} + +func TestNewRateLimiter_Good(t *testing.T) { + rl := NewRateLimiter() + assert.NotNil(t, rl) + assert.NotNil(t, rl.delays) + assert.NotNil(t, rl.last) + assert.Len(t, rl.delays, len(defaultDelays)) +} diff --git a/collect/state.go b/collect/state.go new file mode 100644 
index 0000000..14b38a9 --- /dev/null +++ b/collect/state.go @@ -0,0 +1,113 @@ +package collect + +import ( + "encoding/json" + "sync" + "time" + + core "forge.lthn.ai/core/go/pkg/framework/core" + "forge.lthn.ai/core/go/pkg/io" +) + +// State tracks collection progress for incremental runs. +// It persists entries to disk so that subsequent runs can resume +// where they left off. +type State struct { + mu sync.Mutex + medium io.Medium + path string + entries map[string]*StateEntry +} + +// StateEntry tracks state for one source. +type StateEntry struct { + // Source identifies the collector. + Source string `json:"source"` + + // LastRun is the timestamp of the last successful run. + LastRun time.Time `json:"last_run"` + + // LastID is an opaque identifier for the last item processed. + LastID string `json:"last_id,omitempty"` + + // Items is the total number of items collected so far. + Items int `json:"items"` + + // Cursor is an opaque pagination cursor for resumption. + Cursor string `json:"cursor,omitempty"` +} + +// NewState creates a state tracker that persists to the given path +// using the provided storage medium. +func NewState(m io.Medium, path string) *State { + return &State{ + medium: m, + path: path, + entries: make(map[string]*StateEntry), + } +} + +// Load reads state from disk. If the file does not exist, the state +// is initialised as empty without error. +func (s *State) Load() error { + s.mu.Lock() + defer s.mu.Unlock() + + if !s.medium.IsFile(s.path) { + return nil + } + + data, err := s.medium.Read(s.path) + if err != nil { + return core.E("collect.State.Load", "failed to read state file", err) + } + + var entries map[string]*StateEntry + if err := json.Unmarshal([]byte(data), &entries); err != nil { + return core.E("collect.State.Load", "failed to parse state file", err) + } + + if entries == nil { + entries = make(map[string]*StateEntry) + } + s.entries = entries + return nil +} + +// Save writes state to disk. +func (s *State) Save() error { + s.mu.Lock() + defer s.mu.Unlock() + + data, err := json.MarshalIndent(s.entries, "", " ") + if err != nil { + return core.E("collect.State.Save", "failed to marshal state", err) + } + + if err := s.medium.Write(s.path, string(data)); err != nil { + return core.E("collect.State.Save", "failed to write state file", err) + } + + return nil +} + +// Get returns a copy of the state for a source. The second return value +// indicates whether the entry was found. +func (s *State) Get(source string) (*StateEntry, bool) { + s.mu.Lock() + defer s.mu.Unlock() + entry, ok := s.entries[source] + if !ok { + return nil, false + } + // Return a copy to avoid callers mutating internal state. + cp := *entry + return &cp, true +} + +// Set updates state for a source. 
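+//
+// Typical incremental-run flow (medium and the source key are illustrative):
+//
+//	st := NewState(medium, "/state.json")
+//	_ = st.Load()
+//	st.Set("github:host-uk/core", &StateEntry{Source: "github:host-uk/core", LastRun: time.Now(), Items: 42})
+//	_ = st.Save()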
+func (s *State) Set(source string, entry *StateEntry) { + s.mu.Lock() + defer s.mu.Unlock() + s.entries[source] = entry +} diff --git a/collect/state_test.go b/collect/state_test.go new file mode 100644 index 0000000..90b48bd --- /dev/null +++ b/collect/state_test.go @@ -0,0 +1,144 @@ +package collect + +import ( + "testing" + "time" + + "forge.lthn.ai/core/go/pkg/io" + "github.com/stretchr/testify/assert" +) + +func TestState_SetGet_Good(t *testing.T) { + m := io.NewMockMedium() + s := NewState(m, "/state.json") + + entry := &StateEntry{ + Source: "github:test", + LastRun: time.Now(), + Items: 42, + LastID: "abc123", + Cursor: "cursor-xyz", + } + + s.Set("github:test", entry) + + got, ok := s.Get("github:test") + assert.True(t, ok) + assert.Equal(t, entry.Source, got.Source) + assert.Equal(t, entry.Items, got.Items) + assert.Equal(t, entry.LastID, got.LastID) + assert.Equal(t, entry.Cursor, got.Cursor) +} + +func TestState_Get_Bad(t *testing.T) { + m := io.NewMockMedium() + s := NewState(m, "/state.json") + + got, ok := s.Get("nonexistent") + assert.False(t, ok) + assert.Nil(t, got) +} + +func TestState_SaveLoad_Good(t *testing.T) { + m := io.NewMockMedium() + s := NewState(m, "/state.json") + + now := time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC) + entry := &StateEntry{ + Source: "market:bitcoin", + LastRun: now, + Items: 100, + LastID: "btc-100", + } + + s.Set("market:bitcoin", entry) + + // Save state + err := s.Save() + assert.NoError(t, err) + + // Verify file was written + assert.True(t, m.IsFile("/state.json")) + + // Load into a new state instance + s2 := NewState(m, "/state.json") + err = s2.Load() + assert.NoError(t, err) + + got, ok := s2.Get("market:bitcoin") + assert.True(t, ok) + assert.Equal(t, "market:bitcoin", got.Source) + assert.Equal(t, 100, got.Items) + assert.Equal(t, "btc-100", got.LastID) + assert.True(t, now.Equal(got.LastRun)) +} + +func TestState_Load_Good_NoFile(t *testing.T) { + m := io.NewMockMedium() + s := NewState(m, "/nonexistent.json") + + // Loading when no file exists should not error + err := s.Load() + assert.NoError(t, err) + + // State should be empty + _, ok := s.Get("anything") + assert.False(t, ok) +} + +func TestState_Load_Bad_InvalidJSON(t *testing.T) { + m := io.NewMockMedium() + m.Files["/state.json"] = "not valid json" + + s := NewState(m, "/state.json") + err := s.Load() + assert.Error(t, err) +} + +func TestState_SaveLoad_Good_MultipleEntries(t *testing.T) { + m := io.NewMockMedium() + s := NewState(m, "/state.json") + + s.Set("source-a", &StateEntry{Source: "source-a", Items: 10}) + s.Set("source-b", &StateEntry{Source: "source-b", Items: 20}) + s.Set("source-c", &StateEntry{Source: "source-c", Items: 30}) + + err := s.Save() + assert.NoError(t, err) + + s2 := NewState(m, "/state.json") + err = s2.Load() + assert.NoError(t, err) + + a, ok := s2.Get("source-a") + assert.True(t, ok) + assert.Equal(t, 10, a.Items) + + b, ok := s2.Get("source-b") + assert.True(t, ok) + assert.Equal(t, 20, b.Items) + + c, ok := s2.Get("source-c") + assert.True(t, ok) + assert.Equal(t, 30, c.Items) +} + +func TestState_Set_Good_Overwrite(t *testing.T) { + m := io.NewMockMedium() + s := NewState(m, "/state.json") + + s.Set("source", &StateEntry{Source: "source", Items: 5}) + s.Set("source", &StateEntry{Source: "source", Items: 15}) + + got, ok := s.Get("source") + assert.True(t, ok) + assert.Equal(t, 15, got.Items) +} + +func TestNewState_Good(t *testing.T) { + m := io.NewMockMedium() + s := NewState(m, "/test/state.json") + + assert.NotNil(t, s) + 
assert.NotNil(t, s.entries) +} diff --git a/forge/client.go b/forge/client.go new file mode 100644 index 0000000..fb61c30 --- /dev/null +++ b/forge/client.go @@ -0,0 +1,73 @@ +// Package forge provides a thin wrapper around the Forgejo Go SDK +// for managing repositories, issues, and pull requests on a Forgejo instance. +// +// Authentication is resolved from config file, environment variables, or flag overrides: +// +// 1. ~/.core/config.yaml keys: forge.token, forge.url +// 2. FORGE_TOKEN + FORGE_URL environment variables (override config file) +// 3. Flag overrides via core forge config --url/--token (highest priority) +package forge + +import ( + forgejo "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2" + + "forge.lthn.ai/core/go/pkg/log" +) + +// Client wraps the Forgejo SDK client with config-based auth. +type Client struct { + api *forgejo.Client + url string + token string +} + +// New creates a new Forgejo API client for the given URL and token. +func New(url, token string) (*Client, error) { + api, err := forgejo.NewClient(url, forgejo.SetToken(token)) + if err != nil { + return nil, log.E("forge.New", "failed to create client", err) + } + + return &Client{api: api, url: url, token: token}, nil +} + +// API exposes the underlying SDK client for direct access. +func (c *Client) API() *forgejo.Client { return c.api } + +// URL returns the Forgejo instance URL. +func (c *Client) URL() string { return c.url } + +// Token returns the Forgejo API token. +func (c *Client) Token() string { return c.token } + +// GetCurrentUser returns the authenticated user's information. +func (c *Client) GetCurrentUser() (*forgejo.User, error) { + user, _, err := c.api.GetMyUserInfo() + if err != nil { + return nil, log.E("forge.GetCurrentUser", "failed to get current user", err) + } + return user, nil +} + +// ForkRepo forks a repository. If org is non-empty, forks into that organisation. +func (c *Client) ForkRepo(owner, repo string, org string) (*forgejo.Repository, error) { + opts := forgejo.CreateForkOption{} + if org != "" { + opts.Organization = &org + } + + fork, _, err := c.api.CreateFork(owner, repo, opts) + if err != nil { + return nil, log.E("forge.ForkRepo", "failed to fork repository", err) + } + return fork, nil +} + +// CreatePullRequest creates a pull request on the given repository. +func (c *Client) CreatePullRequest(owner, repo string, opts forgejo.CreatePullRequestOption) (*forgejo.PullRequest, error) { + pr, _, err := c.api.CreatePullRequest(owner, repo, opts) + if err != nil { + return nil, log.E("forge.CreatePullRequest", "failed to create pull request", err) + } + return pr, nil +} diff --git a/forge/config.go b/forge/config.go new file mode 100644 index 0000000..941bbf3 --- /dev/null +++ b/forge/config.go @@ -0,0 +1,92 @@ +package forge + +import ( + "os" + + "forge.lthn.ai/core/go/pkg/config" + "forge.lthn.ai/core/go/pkg/log" +) + +const ( + // ConfigKeyURL is the config key for the Forgejo instance URL. + ConfigKeyURL = "forge.url" + // ConfigKeyToken is the config key for the Forgejo API token. + ConfigKeyToken = "forge.token" + + // DefaultURL is the default Forgejo instance URL. + DefaultURL = "http://localhost:4000" +) + +// NewFromConfig creates a Forgejo client using the standard config resolution: +// +// 1. ~/.core/config.yaml keys: forge.token, forge.url +// 2. FORGE_TOKEN + FORGE_URL environment variables (override config file) +// 3. 
Provided flag overrides (highest priority; pass empty to skip) +func NewFromConfig(flagURL, flagToken string) (*Client, error) { + url, token, err := ResolveConfig(flagURL, flagToken) + if err != nil { + return nil, err + } + + if token == "" { + return nil, log.E("forge.NewFromConfig", "no API token configured (set FORGE_TOKEN or run: core forge config --token TOKEN)", nil) + } + + return New(url, token) +} + +// ResolveConfig resolves the Forgejo URL and token from all config sources. +// Flag values take highest priority, then env vars, then config file. +func ResolveConfig(flagURL, flagToken string) (url, token string, err error) { + // Start with config file values + cfg, cfgErr := config.New() + if cfgErr == nil { + _ = cfg.Get(ConfigKeyURL, &url) + _ = cfg.Get(ConfigKeyToken, &token) + } + + // Overlay environment variables + if envURL := os.Getenv("FORGE_URL"); envURL != "" { + url = envURL + } + if envToken := os.Getenv("FORGE_TOKEN"); envToken != "" { + token = envToken + } + + // Overlay flag values (highest priority) + if flagURL != "" { + url = flagURL + } + if flagToken != "" { + token = flagToken + } + + // Default URL if nothing configured + if url == "" { + url = DefaultURL + } + + return url, token, nil +} + +// SaveConfig persists the Forgejo URL and/or token to the config file. +func SaveConfig(url, token string) error { + cfg, err := config.New() + if err != nil { + return log.E("forge.SaveConfig", "failed to load config", err) + } + + if url != "" { + if err := cfg.Set(ConfigKeyURL, url); err != nil { + return log.E("forge.SaveConfig", "failed to save URL", err) + } + } + + if token != "" { + if err := cfg.Set(ConfigKeyToken, token); err != nil { + return log.E("forge.SaveConfig", "failed to save token", err) + } + } + + return nil +} diff --git a/forge/issues.go b/forge/issues.go new file mode 100644 index 0000000..28a4c7e --- /dev/null +++ b/forge/issues.go @@ -0,0 +1,181 @@ +package forge + +import ( + forgejo "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2" + + "forge.lthn.ai/core/go/pkg/log" +) + +// ListIssuesOpts configures issue listing. +type ListIssuesOpts struct { + State string // "open", "closed", "all" + Labels []string // filter by label names + Page int + Limit int +} + +// ListIssues returns issues for the given repository. +func (c *Client) ListIssues(owner, repo string, opts ListIssuesOpts) ([]*forgejo.Issue, error) { + state := forgejo.StateOpen + switch opts.State { + case "closed": + state = forgejo.StateClosed + case "all": + state = forgejo.StateAll + } + + limit := opts.Limit + if limit == 0 { + limit = 50 + } + + page := opts.Page + if page == 0 { + page = 1 + } + + listOpt := forgejo.ListIssueOption{ + ListOptions: forgejo.ListOptions{Page: page, PageSize: limit}, + State: state, + Type: forgejo.IssueTypeIssue, + Labels: opts.Labels, + } + + issues, _, err := c.api.ListRepoIssues(owner, repo, listOpt) + if err != nil { + return nil, log.E("forge.ListIssues", "failed to list issues", err) + } + + return issues, nil +} + +// GetIssue returns a single issue by number. +func (c *Client) GetIssue(owner, repo string, number int64) (*forgejo.Issue, error) { + issue, _, err := c.api.GetIssue(owner, repo, number) + if err != nil { + return nil, log.E("forge.GetIssue", "failed to get issue", err) + } + + return issue, nil +} + +// CreateIssue creates a new issue in the given repository. 
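+//
+// A minimal usage sketch — the owner, repo, and field values below are
+// illustrative, not taken from this codebase:
+//
+//	issue, err := client.CreateIssue("lthn", "core", forgejo.CreateIssueOption{
+//		Title: "Collector state resets on restart",
+//		Body:  "Seen intermittently on CI.",
+//	})
+//	if err != nil {
+//		return err
+//	}
+//	fmt.Println("created issue", issue.Index)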
+func (c *Client) CreateIssue(owner, repo string, opts forgejo.CreateIssueOption) (*forgejo.Issue, error) { + issue, _, err := c.api.CreateIssue(owner, repo, opts) + if err != nil { + return nil, log.E("forge.CreateIssue", "failed to create issue", err) + } + + return issue, nil +} + +// EditIssue edits an existing issue. +func (c *Client) EditIssue(owner, repo string, number int64, opts forgejo.EditIssueOption) (*forgejo.Issue, error) { + issue, _, err := c.api.EditIssue(owner, repo, number, opts) + if err != nil { + return nil, log.E("forge.EditIssue", "failed to edit issue", err) + } + + return issue, nil +} + +// AssignIssue assigns an issue to the specified users. +func (c *Client) AssignIssue(owner, repo string, number int64, assignees []string) error { + _, _, err := c.api.EditIssue(owner, repo, number, forgejo.EditIssueOption{ + Assignees: assignees, + }) + if err != nil { + return log.E("forge.AssignIssue", "failed to assign issue", err) + } + return nil +} + +// ListPullRequests returns pull requests for the given repository. +func (c *Client) ListPullRequests(owner, repo string, state string) ([]*forgejo.PullRequest, error) { + st := forgejo.StateOpen + switch state { + case "closed": + st = forgejo.StateClosed + case "all": + st = forgejo.StateAll + } + + var all []*forgejo.PullRequest + page := 1 + + for { + prs, resp, err := c.api.ListRepoPullRequests(owner, repo, forgejo.ListPullRequestsOptions{ + ListOptions: forgejo.ListOptions{Page: page, PageSize: 50}, + State: st, + }) + if err != nil { + return nil, log.E("forge.ListPullRequests", "failed to list pull requests", err) + } + + all = append(all, prs...) + + if resp == nil || page >= resp.LastPage { + break + } + page++ + } + + return all, nil +} + +// GetPullRequest returns a single pull request by number. +func (c *Client) GetPullRequest(owner, repo string, number int64) (*forgejo.PullRequest, error) { + pr, _, err := c.api.GetPullRequest(owner, repo, number) + if err != nil { + return nil, log.E("forge.GetPullRequest", "failed to get pull request", err) + } + + return pr, nil +} + +// CreateIssueComment posts a comment on an issue or pull request. +func (c *Client) CreateIssueComment(owner, repo string, issue int64, body string) error { + _, _, err := c.api.CreateIssueComment(owner, repo, issue, forgejo.CreateIssueCommentOption{ + Body: body, + }) + if err != nil { + return log.E("forge.CreateIssueComment", "failed to create comment", err) + } + return nil +} + +// ListIssueComments returns comments for an issue. +func (c *Client) ListIssueComments(owner, repo string, number int64) ([]*forgejo.Comment, error) { + var all []*forgejo.Comment + page := 1 + + for { + comments, resp, err := c.api.ListIssueComments(owner, repo, number, forgejo.ListIssueCommentOptions{ + ListOptions: forgejo.ListOptions{Page: page, PageSize: 50}, + }) + if err != nil { + return nil, log.E("forge.ListIssueComments", "failed to list comments", err) + } + + all = append(all, comments...) + + if resp == nil || page >= resp.LastPage { + break + } + page++ + } + + return all, nil +} + +// CloseIssue closes an issue by setting its state to closed. 
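+//
+// Typically paired with a merge in pipeline flows; a sketch with an
+// illustrative issue number:
+//
+//	if err := client.CloseIssue("lthn", "core", 42); err != nil {
+//		return err
+//	}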
+func (c *Client) CloseIssue(owner, repo string, number int64) error { + closed := forgejo.StateClosed + _, _, err := c.api.EditIssue(owner, repo, number, forgejo.EditIssueOption{ + State: &closed, + }) + if err != nil { + return log.E("forge.CloseIssue", "failed to close issue", err) + } + return nil +} diff --git a/forge/labels.go b/forge/labels.go new file mode 100644 index 0000000..1418d49 --- /dev/null +++ b/forge/labels.go @@ -0,0 +1,112 @@ +package forge + +import ( + "fmt" + "strings" + + forgejo "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2" + + "forge.lthn.ai/core/go/pkg/log" +) + +// ListOrgLabels returns all labels for repos in the given organisation. +// Note: The Forgejo SDK does not have a dedicated org-level labels endpoint. +// This lists labels from the first repo found, which works when orgs use shared label sets. +// For org-wide label management, use ListRepoLabels with a specific repo. +func (c *Client) ListOrgLabels(org string) ([]*forgejo.Label, error) { + // Forgejo doesn't expose org-level labels via SDK — list repos and aggregate unique labels. + repos, err := c.ListOrgRepos(org) + if err != nil { + return nil, err + } + + if len(repos) == 0 { + return nil, nil + } + + // Use the first repo's labels as representative of the org's label set. + return c.ListRepoLabels(repos[0].Owner.UserName, repos[0].Name) +} + +// ListRepoLabels returns all labels for a repository. +func (c *Client) ListRepoLabels(owner, repo string) ([]*forgejo.Label, error) { + var all []*forgejo.Label + page := 1 + + for { + labels, resp, err := c.api.ListRepoLabels(owner, repo, forgejo.ListLabelsOptions{ + ListOptions: forgejo.ListOptions{Page: page, PageSize: 50}, + }) + if err != nil { + return nil, log.E("forge.ListRepoLabels", "failed to list repo labels", err) + } + + all = append(all, labels...) + + if resp == nil || page >= resp.LastPage { + break + } + page++ + } + + return all, nil +} + +// CreateRepoLabel creates a label on a repository. +func (c *Client) CreateRepoLabel(owner, repo string, opts forgejo.CreateLabelOption) (*forgejo.Label, error) { + label, _, err := c.api.CreateLabel(owner, repo, opts) + if err != nil { + return nil, log.E("forge.CreateRepoLabel", "failed to create repo label", err) + } + + return label, nil +} + +// GetLabelByName retrieves a specific label by name from a repository. +func (c *Client) GetLabelByName(owner, repo, name string) (*forgejo.Label, error) { + labels, err := c.ListRepoLabels(owner, repo) + if err != nil { + return nil, err + } + + for _, l := range labels { + if strings.EqualFold(l.Name, name) { + return l, nil + } + } + + return nil, fmt.Errorf("label %s not found in %s/%s", name, owner, repo) +} + +// EnsureLabel checks if a label exists, and creates it if it doesn't. +func (c *Client) EnsureLabel(owner, repo, name, color string) (*forgejo.Label, error) { + label, err := c.GetLabelByName(owner, repo, name) + if err == nil { + return label, nil + } + + return c.CreateRepoLabel(owner, repo, forgejo.CreateLabelOption{ + Name: name, + Color: color, + }) +} + +// AddIssueLabels adds labels to an issue. +func (c *Client) AddIssueLabels(owner, repo string, number int64, labelIDs []int64) error { + _, _, err := c.api.AddIssueLabels(owner, repo, number, forgejo.IssueLabelsOption{ + Labels: labelIDs, + }) + if err != nil { + return log.E("forge.AddIssueLabels", "failed to add labels to issue", err) + } + return nil +} + +// RemoveIssueLabel removes a label from an issue. 
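+//
+// The label ID can be resolved by name first; a sketch with illustrative
+// values:
+//
+//	label, err := client.GetLabelByName("lthn", "core", "needs-review")
+//	if err != nil {
+//		return err
+//	}
+//	return client.RemoveIssueLabel("lthn", "core", 42, label.ID)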
+func (c *Client) RemoveIssueLabel(owner, repo string, number int64, labelID int64) error { + _, err := c.api.DeleteIssueLabel(owner, repo, number, labelID) + if err != nil { + return log.E("forge.RemoveIssueLabel", "failed to remove label from issue", err) + } + return nil +} diff --git a/forge/meta.go b/forge/meta.go new file mode 100644 index 0000000..df0930b --- /dev/null +++ b/forge/meta.go @@ -0,0 +1,144 @@ +package forge + +import ( + "time" + + forgejo "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2" + + "forge.lthn.ai/core/go/pkg/log" +) + +// PRMeta holds structural signals from a pull request, +// used by the pipeline MetaReader for AI-driven workflows. +type PRMeta struct { + Number int64 + Title string + State string + Author string + Branch string + BaseBranch string + Labels []string + Assignees []string + IsMerged bool + CreatedAt time.Time + UpdatedAt time.Time + CommentCount int +} + +// Comment represents a comment with metadata. +type Comment struct { + ID int64 + Author string + Body string + CreatedAt time.Time + UpdatedAt time.Time +} + +const commentPageSize = 50 + +// GetPRMeta returns structural signals for a pull request. +// This is the Forgejo side of the dual MetaReader described in the pipeline design. +func (c *Client) GetPRMeta(owner, repo string, pr int64) (*PRMeta, error) { + pull, _, err := c.api.GetPullRequest(owner, repo, pr) + if err != nil { + return nil, log.E("forge.GetPRMeta", "failed to get PR metadata", err) + } + + meta := &PRMeta{ + Number: pull.Index, + Title: pull.Title, + State: string(pull.State), + Branch: pull.Head.Ref, + BaseBranch: pull.Base.Ref, + IsMerged: pull.HasMerged, + } + + if pull.Created != nil { + meta.CreatedAt = *pull.Created + } + if pull.Updated != nil { + meta.UpdatedAt = *pull.Updated + } + + if pull.Poster != nil { + meta.Author = pull.Poster.UserName + } + + for _, label := range pull.Labels { + meta.Labels = append(meta.Labels, label.Name) + } + + for _, assignee := range pull.Assignees { + meta.Assignees = append(meta.Assignees, assignee.UserName) + } + + // Fetch comment count from the issue side (PRs are issues in Forgejo). + // Paginate to get an accurate count. + count := 0 + page := 1 + for { + comments, _, listErr := c.api.ListIssueComments(owner, repo, pr, forgejo.ListIssueCommentOptions{ + ListOptions: forgejo.ListOptions{Page: page, PageSize: commentPageSize}, + }) + if listErr != nil { + break + } + count += len(comments) + if len(comments) < commentPageSize { + break + } + page++ + } + meta.CommentCount = count + + return meta, nil +} + +// GetCommentBodies returns all comment bodies for a pull request. +func (c *Client) GetCommentBodies(owner, repo string, pr int64) ([]Comment, error) { + var comments []Comment + page := 1 + + for { + raw, _, err := c.api.ListIssueComments(owner, repo, pr, forgejo.ListIssueCommentOptions{ + ListOptions: forgejo.ListOptions{Page: page, PageSize: commentPageSize}, + }) + if err != nil { + return nil, log.E("forge.GetCommentBodies", "failed to get PR comments", err) + } + + if len(raw) == 0 { + break + } + + for _, rc := range raw { + comment := Comment{ + ID: rc.ID, + Body: rc.Body, + CreatedAt: rc.Created, + UpdatedAt: rc.Updated, + } + if rc.Poster != nil { + comment.Author = rc.Poster.UserName + } + comments = append(comments, comment) + } + + if len(raw) < commentPageSize { + break + } + page++ + } + + return comments, nil +} + +// GetIssueBody returns the body text of an issue. 
+func (c *Client) GetIssueBody(owner, repo string, issue int64) (string, error) { + iss, _, err := c.api.GetIssue(owner, repo, issue) + if err != nil { + return "", log.E("forge.GetIssueBody", "failed to get issue body", err) + } + + return iss.Body, nil +} diff --git a/forge/orgs.go b/forge/orgs.go new file mode 100644 index 0000000..cce5097 --- /dev/null +++ b/forge/orgs.go @@ -0,0 +1,51 @@ +package forge + +import ( + forgejo "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2" + + "forge.lthn.ai/core/go/pkg/log" +) + +// ListMyOrgs returns all organisations for the authenticated user. +func (c *Client) ListMyOrgs() ([]*forgejo.Organization, error) { + var all []*forgejo.Organization + page := 1 + + for { + orgs, resp, err := c.api.ListMyOrgs(forgejo.ListOrgsOptions{ + ListOptions: forgejo.ListOptions{Page: page, PageSize: 50}, + }) + if err != nil { + return nil, log.E("forge.ListMyOrgs", "failed to list orgs", err) + } + + all = append(all, orgs...) + + if resp == nil || page >= resp.LastPage { + break + } + page++ + } + + return all, nil +} + +// GetOrg returns a single organisation by name. +func (c *Client) GetOrg(name string) (*forgejo.Organization, error) { + org, _, err := c.api.GetOrg(name) + if err != nil { + return nil, log.E("forge.GetOrg", "failed to get org", err) + } + + return org, nil +} + +// CreateOrg creates a new organisation. +func (c *Client) CreateOrg(opts forgejo.CreateOrgOption) (*forgejo.Organization, error) { + org, _, err := c.api.CreateOrg(opts) + if err != nil { + return nil, log.E("forge.CreateOrg", "failed to create org", err) + } + + return org, nil +} diff --git a/forge/prs.go b/forge/prs.go new file mode 100644 index 0000000..465ebae --- /dev/null +++ b/forge/prs.go @@ -0,0 +1,109 @@ +package forge + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + + forgejo "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2" + + "forge.lthn.ai/core/go/pkg/log" +) + +// MergePullRequest merges a pull request with the given method ("squash", "rebase", "merge"). +func (c *Client) MergePullRequest(owner, repo string, index int64, method string) error { + style := forgejo.MergeStyleMerge + switch method { + case "squash": + style = forgejo.MergeStyleSquash + case "rebase": + style = forgejo.MergeStyleRebase + } + + merged, _, err := c.api.MergePullRequest(owner, repo, index, forgejo.MergePullRequestOption{ + Style: style, + DeleteBranchAfterMerge: true, + }) + if err != nil { + return log.E("forge.MergePullRequest", "failed to merge pull request", err) + } + if !merged { + return log.E("forge.MergePullRequest", fmt.Sprintf("merge returned false for %s/%s#%d", owner, repo, index), nil) + } + return nil +} + +// SetPRDraft sets or clears the draft status on a pull request. +// The Forgejo SDK v2.2.0 doesn't expose the draft field on EditPullRequestOption, +// so we use a raw HTTP PATCH request. 
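+//
+// A sketch of marking a PR ready for review once checks pass (owner,
+// repo, and index are illustrative):
+//
+//	if err := client.SetPRDraft("lthn", "core", 7, false); err != nil {
+//		return err
+//	}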
+func (c *Client) SetPRDraft(owner, repo string, index int64, draft bool) error { + payload := map[string]bool{"draft": draft} + body, err := json.Marshal(payload) + if err != nil { + return log.E("forge.SetPRDraft", "marshal payload", err) + } + + url := fmt.Sprintf("%s/api/v1/repos/%s/%s/pulls/%d", c.url, owner, repo, index) + req, err := http.NewRequest(http.MethodPatch, url, bytes.NewReader(body)) + if err != nil { + return log.E("forge.SetPRDraft", "create request", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "token "+c.token) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return log.E("forge.SetPRDraft", "failed to update draft status", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return log.E("forge.SetPRDraft", fmt.Sprintf("unexpected status %d", resp.StatusCode), nil) + } + return nil +} + +// ListPRReviews returns all reviews for a pull request. +func (c *Client) ListPRReviews(owner, repo string, index int64) ([]*forgejo.PullReview, error) { + var all []*forgejo.PullReview + page := 1 + + for { + reviews, resp, err := c.api.ListPullReviews(owner, repo, index, forgejo.ListPullReviewsOptions{ + ListOptions: forgejo.ListOptions{Page: page, PageSize: 50}, + }) + if err != nil { + return nil, log.E("forge.ListPRReviews", "failed to list reviews", err) + } + + all = append(all, reviews...) + + if resp == nil || page >= resp.LastPage { + break + } + page++ + } + + return all, nil +} + +// GetCombinedStatus returns the combined commit status for a ref (SHA or branch). +func (c *Client) GetCombinedStatus(owner, repo string, ref string) (*forgejo.CombinedStatus, error) { + status, _, err := c.api.GetCombinedStatus(owner, repo, ref) + if err != nil { + return nil, log.E("forge.GetCombinedStatus", "failed to get combined status", err) + } + return status, nil +} + +// DismissReview dismisses a pull request review by ID. +func (c *Client) DismissReview(owner, repo string, index, reviewID int64, message string) error { + _, err := c.api.DismissPullReview(owner, repo, index, reviewID, forgejo.DismissPullReviewOptions{ + Message: message, + }) + if err != nil { + return log.E("forge.DismissReview", "failed to dismiss review", err) + } + return nil +} diff --git a/forge/repos.go b/forge/repos.go new file mode 100644 index 0000000..504d5db --- /dev/null +++ b/forge/repos.go @@ -0,0 +1,96 @@ +package forge + +import ( + forgejo "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2" + + "forge.lthn.ai/core/go/pkg/log" +) + +// ListOrgRepos returns all repositories for the given organisation. +func (c *Client) ListOrgRepos(org string) ([]*forgejo.Repository, error) { + var all []*forgejo.Repository + page := 1 + + for { + repos, resp, err := c.api.ListOrgRepos(org, forgejo.ListOrgReposOptions{ + ListOptions: forgejo.ListOptions{Page: page, PageSize: 50}, + }) + if err != nil { + return nil, log.E("forge.ListOrgRepos", "failed to list org repos", err) + } + + all = append(all, repos...) + + if resp == nil || page >= resp.LastPage { + break + } + page++ + } + + return all, nil +} + +// ListUserRepos returns all repositories for the authenticated user. 
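+//
+// Pagination is handled internally; a usage sketch:
+//
+//	repos, err := client.ListUserRepos()
+//	if err != nil {
+//		return err
+//	}
+//	for _, r := range repos {
+//		fmt.Println(r.FullName)
+//	}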
+func (c *Client) ListUserRepos() ([]*forgejo.Repository, error) { + var all []*forgejo.Repository + page := 1 + + for { + repos, resp, err := c.api.ListMyRepos(forgejo.ListReposOptions{ + ListOptions: forgejo.ListOptions{Page: page, PageSize: 50}, + }) + if err != nil { + return nil, log.E("forge.ListUserRepos", "failed to list user repos", err) + } + + all = append(all, repos...) + + if resp == nil || page >= resp.LastPage { + break + } + page++ + } + + return all, nil +} + +// GetRepo returns a single repository by owner and name. +func (c *Client) GetRepo(owner, name string) (*forgejo.Repository, error) { + repo, _, err := c.api.GetRepo(owner, name) + if err != nil { + return nil, log.E("forge.GetRepo", "failed to get repo", err) + } + + return repo, nil +} + +// CreateOrgRepo creates a new empty repository under an organisation. +func (c *Client) CreateOrgRepo(org string, opts forgejo.CreateRepoOption) (*forgejo.Repository, error) { + repo, _, err := c.api.CreateOrgRepo(org, opts) + if err != nil { + return nil, log.E("forge.CreateOrgRepo", "failed to create org repo", err) + } + + return repo, nil +} + +// DeleteRepo deletes a repository from Forgejo. +func (c *Client) DeleteRepo(owner, name string) error { + _, err := c.api.DeleteRepo(owner, name) + if err != nil { + return log.E("forge.DeleteRepo", "failed to delete repo", err) + } + + return nil +} + +// MigrateRepo migrates a repository from an external service using the Forgejo migration API. +// Unlike CreateMirror, this supports importing issues, labels, PRs, and more. +func (c *Client) MigrateRepo(opts forgejo.MigrateRepoOption) (*forgejo.Repository, error) { + repo, _, err := c.api.MigrateRepo(opts) + if err != nil { + return nil, log.E("forge.MigrateRepo", "failed to migrate repo", err) + } + + return repo, nil +} diff --git a/forge/webhooks.go b/forge/webhooks.go new file mode 100644 index 0000000..6d13b74 --- /dev/null +++ b/forge/webhooks.go @@ -0,0 +1,41 @@ +package forge + +import ( + forgejo "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2" + + "forge.lthn.ai/core/go/pkg/log" +) + +// CreateRepoWebhook creates a webhook on a repository. +func (c *Client) CreateRepoWebhook(owner, repo string, opts forgejo.CreateHookOption) (*forgejo.Hook, error) { + hook, _, err := c.api.CreateRepoHook(owner, repo, opts) + if err != nil { + return nil, log.E("forge.CreateRepoWebhook", "failed to create repo webhook", err) + } + + return hook, nil +} + +// ListRepoWebhooks returns all webhooks for a repository. +func (c *Client) ListRepoWebhooks(owner, repo string) ([]*forgejo.Hook, error) { + var all []*forgejo.Hook + page := 1 + + for { + hooks, resp, err := c.api.ListRepoHooks(owner, repo, forgejo.ListHooksOptions{ + ListOptions: forgejo.ListOptions{Page: page, PageSize: 50}, + }) + if err != nil { + return nil, log.E("forge.ListRepoWebhooks", "failed to list repo webhooks", err) + } + + all = append(all, hooks...) + + if resp == nil || page >= resp.LastPage { + break + } + page++ + } + + return all, nil +} diff --git a/git/git.go b/git/git.go new file mode 100644 index 0000000..9f5460c --- /dev/null +++ b/git/git.go @@ -0,0 +1,265 @@ +// Package git provides utilities for git operations across multiple repositories. +package git + +import ( + "bytes" + "context" + "io" + "os" + "os/exec" + "strconv" + "strings" + "sync" +) + +// RepoStatus represents the git status of a single repository. 
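+//
+// Statuses are produced by Status and inspected through the helper
+// methods; a sketch with illustrative paths:
+//
+//	statuses := git.Status(ctx, git.StatusOptions{
+//		Paths: []string{"/src/core", "/src/scm"},
+//	})
+//	for _, st := range statuses {
+//		if st.Error == nil && st.IsDirty() {
+//			fmt.Println(st.Name, "has uncommitted changes")
+//		}
+//	}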
+type RepoStatus struct { + Name string + Path string + Modified int + Untracked int + Staged int + Ahead int + Behind int + Branch string + Error error +} + +// IsDirty returns true if there are uncommitted changes. +func (s *RepoStatus) IsDirty() bool { + return s.Modified > 0 || s.Untracked > 0 || s.Staged > 0 +} + +// HasUnpushed returns true if there are commits to push. +func (s *RepoStatus) HasUnpushed() bool { + return s.Ahead > 0 +} + +// HasUnpulled returns true if there are commits to pull. +func (s *RepoStatus) HasUnpulled() bool { + return s.Behind > 0 +} + +// StatusOptions configures the status check. +type StatusOptions struct { + // Paths is a list of repo paths to check + Paths []string + // Names maps paths to display names + Names map[string]string +} + +// Status checks git status for multiple repositories in parallel. +func Status(ctx context.Context, opts StatusOptions) []RepoStatus { + var wg sync.WaitGroup + results := make([]RepoStatus, len(opts.Paths)) + + for i, path := range opts.Paths { + wg.Add(1) + go func(idx int, repoPath string) { + defer wg.Done() + name := opts.Names[repoPath] + if name == "" { + name = repoPath + } + results[idx] = getStatus(ctx, repoPath, name) + }(i, path) + } + + wg.Wait() + return results +} + +// getStatus gets the git status for a single repository. +func getStatus(ctx context.Context, path, name string) RepoStatus { + status := RepoStatus{ + Name: name, + Path: path, + } + + // Get current branch + branch, err := gitCommand(ctx, path, "rev-parse", "--abbrev-ref", "HEAD") + if err != nil { + status.Error = err + return status + } + status.Branch = strings.TrimSpace(branch) + + // Get porcelain status + porcelain, err := gitCommand(ctx, path, "status", "--porcelain") + if err != nil { + status.Error = err + return status + } + + // Parse status output + for _, line := range strings.Split(porcelain, "\n") { + if len(line) < 2 { + continue + } + x, y := line[0], line[1] + + // Untracked + if x == '?' && y == '?' { + status.Untracked++ + continue + } + + // Staged (index has changes) + if x == 'A' || x == 'D' || x == 'R' || x == 'M' { + status.Staged++ + } + + // Modified in working tree + if y == 'M' || y == 'D' { + status.Modified++ + } + } + + // Get ahead/behind counts + ahead, behind := getAheadBehind(ctx, path) + status.Ahead = ahead + status.Behind = behind + + return status +} + +// getAheadBehind returns the number of commits ahead and behind upstream. +func getAheadBehind(ctx context.Context, path string) (ahead, behind int) { + // Try to get ahead count + aheadStr, err := gitCommand(ctx, path, "rev-list", "--count", "@{u}..HEAD") + if err == nil { + ahead, _ = strconv.Atoi(strings.TrimSpace(aheadStr)) + } + + // Try to get behind count + behindStr, err := gitCommand(ctx, path, "rev-list", "--count", "HEAD..@{u}") + if err == nil { + behind, _ = strconv.Atoi(strings.TrimSpace(behindStr)) + } + + return ahead, behind +} + +// Push pushes commits for a single repository. +// Uses interactive mode to support SSH passphrase prompts. +func Push(ctx context.Context, path string) error { + return gitInteractive(ctx, path, "push") +} + +// Pull pulls changes for a single repository. +// Uses interactive mode to support SSH passphrase prompts. +func Pull(ctx context.Context, path string) error { + return gitInteractive(ctx, path, "pull", "--rebase") +} + +// IsNonFastForward checks if an error is a non-fast-forward rejection. 
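+//
+// A sketch of the intended retry flow — rebase onto upstream, then push
+// again:
+//
+//	err := git.Push(ctx, path)
+//	if git.IsNonFastForward(err) {
+//		if err = git.Pull(ctx, path); err == nil {
+//			err = git.Push(ctx, path)
+//		}
+//	}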
+func IsNonFastForward(err error) bool { + if err == nil { + return false + } + msg := err.Error() + return strings.Contains(msg, "non-fast-forward") || + strings.Contains(msg, "fetch first") || + strings.Contains(msg, "tip of your current branch is behind") +} + +// gitInteractive runs a git command with terminal attached for user interaction. +func gitInteractive(ctx context.Context, dir string, args ...string) error { + cmd := exec.CommandContext(ctx, "git", args...) + cmd.Dir = dir + + // Connect to terminal for SSH passphrase prompts + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + + // Capture stderr for error reporting while also showing it + var stderr bytes.Buffer + cmd.Stderr = io.MultiWriter(os.Stderr, &stderr) + + if err := cmd.Run(); err != nil { + if stderr.Len() > 0 { + return &GitError{Err: err, Stderr: stderr.String()} + } + return err + } + + return nil +} + +// PushResult represents the result of a push operation. +type PushResult struct { + Name string + Path string + Success bool + Error error +} + +// PushMultiple pushes multiple repositories sequentially. +// Sequential because SSH passphrase prompts need user interaction. +func PushMultiple(ctx context.Context, paths []string, names map[string]string) []PushResult { + results := make([]PushResult, len(paths)) + + for i, path := range paths { + name := names[path] + if name == "" { + name = path + } + + result := PushResult{ + Name: name, + Path: path, + } + + err := Push(ctx, path) + if err != nil { + result.Error = err + } else { + result.Success = true + } + + results[i] = result + } + + return results +} + +// gitCommand runs a git command and returns stdout. +func gitCommand(ctx context.Context, dir string, args ...string) (string, error) { + cmd := exec.CommandContext(ctx, "git", args...) + cmd.Dir = dir + + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + // Include stderr in error message for better diagnostics + if stderr.Len() > 0 { + return "", &GitError{Err: err, Stderr: stderr.String()} + } + return "", err + } + + return stdout.String(), nil +} + +// GitError wraps a git command error with stderr output. +type GitError struct { + Err error + Stderr string +} + +// Error returns the git error message, preferring stderr output. +func (e *GitError) Error() string { + // Return just the stderr message, trimmed + msg := strings.TrimSpace(e.Stderr) + if msg != "" { + return msg + } + return e.Err.Error() +} + +// Unwrap returns the underlying error for error chain inspection. +func (e *GitError) Unwrap() error { + return e.Err +} diff --git a/git/service.go b/git/service.go new file mode 100644 index 0000000..892d6fc --- /dev/null +++ b/git/service.go @@ -0,0 +1,126 @@ +package git + +import ( + "context" + + "forge.lthn.ai/core/go/pkg/framework" +) + +// Queries for git service + +// QueryStatus requests git status for paths. +type QueryStatus struct { + Paths []string + Names map[string]string +} + +// QueryDirtyRepos requests repos with uncommitted changes. +type QueryDirtyRepos struct{} + +// QueryAheadRepos requests repos with unpushed commits. +type QueryAheadRepos struct{} + +// Tasks for git service + +// TaskPush requests git push for a path. +type TaskPush struct { + Path string + Name string +} + +// TaskPull requests git pull for a path. +type TaskPull struct { + Path string + Name string +} + +// TaskPushMultiple requests git push for multiple paths. 
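+//
+// A sketch of constructing the task with illustrative paths; dispatch
+// goes through the Core task runner (see handleTask below), whose exact
+// call depends on the framework API:
+//
+//	task := git.TaskPushMultiple{
+//		Paths: []string{"/src/core", "/src/scm"},
+//		Names: map[string]string{"/src/core": "core", "/src/scm": "scm"},
+//	}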
+type TaskPushMultiple struct { + Paths []string + Names map[string]string +} + +// ServiceOptions for configuring the git service. +type ServiceOptions struct { + WorkDir string +} + +// Service provides git operations as a Core service. +type Service struct { + *framework.ServiceRuntime[ServiceOptions] + lastStatus []RepoStatus +} + +// NewService creates a git service factory. +func NewService(opts ServiceOptions) func(*framework.Core) (any, error) { + return func(c *framework.Core) (any, error) { + return &Service{ + ServiceRuntime: framework.NewServiceRuntime(c, opts), + }, nil + } +} + +// OnStartup registers query and task handlers. +func (s *Service) OnStartup(ctx context.Context) error { + s.Core().RegisterQuery(s.handleQuery) + s.Core().RegisterTask(s.handleTask) + return nil +} + +func (s *Service) handleQuery(c *framework.Core, q framework.Query) (any, bool, error) { + switch m := q.(type) { + case QueryStatus: + statuses := Status(context.Background(), StatusOptions(m)) + s.lastStatus = statuses + return statuses, true, nil + + case QueryDirtyRepos: + return s.DirtyRepos(), true, nil + + case QueryAheadRepos: + return s.AheadRepos(), true, nil + } + return nil, false, nil +} + +func (s *Service) handleTask(c *framework.Core, t framework.Task) (any, bool, error) { + switch m := t.(type) { + case TaskPush: + err := Push(context.Background(), m.Path) + return nil, true, err + + case TaskPull: + err := Pull(context.Background(), m.Path) + return nil, true, err + + case TaskPushMultiple: + results := PushMultiple(context.Background(), m.Paths, m.Names) + return results, true, nil + } + return nil, false, nil +} + +// Status returns last status result. +func (s *Service) Status() []RepoStatus { return s.lastStatus } + +// DirtyRepos returns repos with uncommitted changes. +func (s *Service) DirtyRepos() []RepoStatus { + var dirty []RepoStatus + for _, st := range s.lastStatus { + if st.Error == nil && st.IsDirty() { + dirty = append(dirty, st) + } + } + return dirty +} + +// AheadRepos returns repos with unpushed commits. +func (s *Service) AheadRepos() []RepoStatus { + var ahead []RepoStatus + for _, st := range s.lastStatus { + if st.Error == nil && st.HasUnpushed() { + ahead = append(ahead, st) + } + } + return ahead +} diff --git a/gitea/client.go b/gitea/client.go new file mode 100644 index 0000000..d05ba21 --- /dev/null +++ b/gitea/client.go @@ -0,0 +1,37 @@ +// Package gitea provides a thin wrapper around the Gitea Go SDK +// for managing repositories, issues, and pull requests on a Gitea instance. +// +// Authentication is resolved from config file, environment variables, or flag overrides: +// +// 1. ~/.core/config.yaml keys: gitea.token, gitea.url +// 2. GITEA_TOKEN + GITEA_URL environment variables (override config file) +// 3. Flag overrides via core gitea config --url/--token (highest priority) +package gitea + +import ( + "code.gitea.io/sdk/gitea" + + "forge.lthn.ai/core/go/pkg/log" +) + +// Client wraps the Gitea SDK client with config-based auth. +type Client struct { + api *gitea.Client + url string +} + +// New creates a new Gitea API client for the given URL and token. +func New(url, token string) (*Client, error) { + api, err := gitea.NewClient(url, gitea.SetToken(token)) + if err != nil { + return nil, log.E("gitea.New", "failed to create client", err) + } + + return &Client{api: api, url: url}, nil +} + +// API exposes the underlying SDK client for direct access. +func (c *Client) API() *gitea.Client { return c.api } + +// URL returns the Gitea instance URL. 
+func (c *Client) URL() string { return c.url } diff --git a/gitea/config.go b/gitea/config.go new file mode 100644 index 0000000..7334854 --- /dev/null +++ b/gitea/config.go @@ -0,0 +1,92 @@ +package gitea + +import ( + "os" + + "forge.lthn.ai/core/go/pkg/config" + "forge.lthn.ai/core/go/pkg/log" +) + +const ( + // ConfigKeyURL is the config key for the Gitea instance URL. + ConfigKeyURL = "gitea.url" + // ConfigKeyToken is the config key for the Gitea API token. + ConfigKeyToken = "gitea.token" + + // DefaultURL is the default Gitea instance URL. + DefaultURL = "https://gitea.snider.dev" +) + +// NewFromConfig creates a Gitea client using the standard config resolution: +// +// 1. ~/.core/config.yaml keys: gitea.token, gitea.url +// 2. GITEA_TOKEN + GITEA_URL environment variables (override config file) +// 3. Provided flag overrides (highest priority; pass empty to skip) +func NewFromConfig(flagURL, flagToken string) (*Client, error) { + url, token, err := ResolveConfig(flagURL, flagToken) + if err != nil { + return nil, err + } + + if token == "" { + return nil, log.E("gitea.NewFromConfig", "no API token configured (set GITEA_TOKEN or run: core gitea config --token TOKEN)", nil) + } + + return New(url, token) +} + +// ResolveConfig resolves the Gitea URL and token from all config sources. +// Flag values take highest priority, then env vars, then config file. +func ResolveConfig(flagURL, flagToken string) (url, token string, err error) { + // Start with config file values + cfg, cfgErr := config.New() + if cfgErr == nil { + _ = cfg.Get(ConfigKeyURL, &url) + _ = cfg.Get(ConfigKeyToken, &token) + } + + // Overlay environment variables + if envURL := os.Getenv("GITEA_URL"); envURL != "" { + url = envURL + } + if envToken := os.Getenv("GITEA_TOKEN"); envToken != "" { + token = envToken + } + + // Overlay flag values (highest priority) + if flagURL != "" { + url = flagURL + } + if flagToken != "" { + token = flagToken + } + + // Default URL if nothing configured + if url == "" { + url = DefaultURL + } + + return url, token, nil +} + +// SaveConfig persists the Gitea URL and/or token to the config file. +func SaveConfig(url, token string) error { + cfg, err := config.New() + if err != nil { + return log.E("gitea.SaveConfig", "failed to load config", err) + } + + if url != "" { + if err := cfg.Set(ConfigKeyURL, url); err != nil { + return log.E("gitea.SaveConfig", "failed to save URL", err) + } + } + + if token != "" { + if err := cfg.Set(ConfigKeyToken, token); err != nil { + return log.E("gitea.SaveConfig", "failed to save token", err) + } + } + + return nil +} diff --git a/gitea/issues.go b/gitea/issues.go new file mode 100644 index 0000000..3f0d788 --- /dev/null +++ b/gitea/issues.go @@ -0,0 +1,109 @@ +package gitea + +import ( + "code.gitea.io/sdk/gitea" + + "forge.lthn.ai/core/go/pkg/log" +) + +// ListIssuesOpts configures issue listing. +type ListIssuesOpts struct { + State string // "open", "closed", "all" + Page int + Limit int +} + +// ListIssues returns issues for the given repository. 
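+//
+// A usage sketch (owner and repo are illustrative; unset fields default
+// to state "open", page 1, limit 50):
+//
+//	issues, err := client.ListIssues("lthn", "core", ListIssuesOpts{State: "open"})
+//	if err != nil {
+//		return err
+//	}
+//	for _, is := range issues {
+//		fmt.Println(is.Index, is.Title)
+//	}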
+func (c *Client) ListIssues(owner, repo string, opts ListIssuesOpts) ([]*gitea.Issue, error) { + state := gitea.StateOpen + switch opts.State { + case "closed": + state = gitea.StateClosed + case "all": + state = gitea.StateAll + } + + limit := opts.Limit + if limit == 0 { + limit = 50 + } + + page := opts.Page + if page == 0 { + page = 1 + } + + issues, _, err := c.api.ListRepoIssues(owner, repo, gitea.ListIssueOption{ + ListOptions: gitea.ListOptions{Page: page, PageSize: limit}, + State: state, + Type: gitea.IssueTypeIssue, + }) + if err != nil { + return nil, log.E("gitea.ListIssues", "failed to list issues", err) + } + + return issues, nil +} + +// GetIssue returns a single issue by number. +func (c *Client) GetIssue(owner, repo string, number int64) (*gitea.Issue, error) { + issue, _, err := c.api.GetIssue(owner, repo, number) + if err != nil { + return nil, log.E("gitea.GetIssue", "failed to get issue", err) + } + + return issue, nil +} + +// CreateIssue creates a new issue in the given repository. +func (c *Client) CreateIssue(owner, repo string, opts gitea.CreateIssueOption) (*gitea.Issue, error) { + issue, _, err := c.api.CreateIssue(owner, repo, opts) + if err != nil { + return nil, log.E("gitea.CreateIssue", "failed to create issue", err) + } + + return issue, nil +} + +// ListPullRequests returns pull requests for the given repository. +func (c *Client) ListPullRequests(owner, repo string, state string) ([]*gitea.PullRequest, error) { + st := gitea.StateOpen + switch state { + case "closed": + st = gitea.StateClosed + case "all": + st = gitea.StateAll + } + + var all []*gitea.PullRequest + page := 1 + + for { + prs, resp, err := c.api.ListRepoPullRequests(owner, repo, gitea.ListPullRequestsOptions{ + ListOptions: gitea.ListOptions{Page: page, PageSize: 50}, + State: st, + }) + if err != nil { + return nil, log.E("gitea.ListPullRequests", "failed to list pull requests", err) + } + + all = append(all, prs...) + + if resp == nil || page >= resp.LastPage { + break + } + page++ + } + + return all, nil +} + +// GetPullRequest returns a single pull request by number. +func (c *Client) GetPullRequest(owner, repo string, number int64) (*gitea.PullRequest, error) { + pr, _, err := c.api.GetPullRequest(owner, repo, number) + if err != nil { + return nil, log.E("gitea.GetPullRequest", "failed to get pull request", err) + } + + return pr, nil +} diff --git a/gitea/meta.go b/gitea/meta.go new file mode 100644 index 0000000..5cb43ba --- /dev/null +++ b/gitea/meta.go @@ -0,0 +1,146 @@ +package gitea + +import ( + "time" + + "code.gitea.io/sdk/gitea" + + "forge.lthn.ai/core/go/pkg/log" +) + +// PRMeta holds structural signals from a pull request, +// used by the pipeline MetaReader for AI-driven workflows. +type PRMeta struct { + Number int64 + Title string + State string + Author string + Branch string + BaseBranch string + Labels []string + Assignees []string + IsMerged bool + CreatedAt time.Time + UpdatedAt time.Time + CommentCount int +} + +// Comment represents a comment with metadata. +type Comment struct { + ID int64 + Author string + Body string + CreatedAt time.Time + UpdatedAt time.Time +} + +const commentPageSize = 50 + +// GetPRMeta returns structural signals for a pull request. +// This is the Gitea side of the dual MetaReader described in the pipeline design. 
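+//
+// A sketch of reading the structural signals (PR number is illustrative):
+//
+//	meta, err := client.GetPRMeta("lthn", "core", 7)
+//	if err != nil {
+//		return err
+//	}
+//	if meta.IsMerged || meta.CommentCount > 0 {
+//		// feed the signals into the pipeline
+//	}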
+func (c *Client) GetPRMeta(owner, repo string, pr int64) (*PRMeta, error) { + pull, _, err := c.api.GetPullRequest(owner, repo, pr) + if err != nil { + return nil, log.E("gitea.GetPRMeta", "failed to get PR metadata", err) + } + + meta := &PRMeta{ + Number: pull.Index, + Title: pull.Title, + State: string(pull.State), + Branch: pull.Head.Ref, + BaseBranch: pull.Base.Ref, + IsMerged: pull.HasMerged, + } + + if pull.Created != nil { + meta.CreatedAt = *pull.Created + } + if pull.Updated != nil { + meta.UpdatedAt = *pull.Updated + } + + if pull.Poster != nil { + meta.Author = pull.Poster.UserName + } + + for _, label := range pull.Labels { + meta.Labels = append(meta.Labels, label.Name) + } + + for _, assignee := range pull.Assignees { + meta.Assignees = append(meta.Assignees, assignee.UserName) + } + + // Fetch comment count from the issue side (PRs are issues in Gitea). + // Paginate to get an accurate count. + count := 0 + page := 1 + for { + comments, _, listErr := c.api.ListIssueComments(owner, repo, pr, gitea.ListIssueCommentOptions{ + ListOptions: gitea.ListOptions{Page: page, PageSize: commentPageSize}, + }) + if listErr != nil { + break + } + count += len(comments) + if len(comments) < commentPageSize { + break + } + page++ + } + meta.CommentCount = count + + return meta, nil +} + +// GetCommentBodies returns all comment bodies for a pull request. +// This reads full content, which is safe on the home lab Gitea instance. +func (c *Client) GetCommentBodies(owner, repo string, pr int64) ([]Comment, error) { + var comments []Comment + page := 1 + + for { + raw, _, err := c.api.ListIssueComments(owner, repo, pr, gitea.ListIssueCommentOptions{ + ListOptions: gitea.ListOptions{Page: page, PageSize: commentPageSize}, + }) + if err != nil { + return nil, log.E("gitea.GetCommentBodies", "failed to get PR comments", err) + } + + if len(raw) == 0 { + break + } + + for _, rc := range raw { + comment := Comment{ + ID: rc.ID, + Body: rc.Body, + CreatedAt: rc.Created, + UpdatedAt: rc.Updated, + } + if rc.Poster != nil { + comment.Author = rc.Poster.UserName + } + comments = append(comments, comment) + } + + if len(raw) < commentPageSize { + break + } + page++ + } + + return comments, nil +} + +// GetIssueBody returns the body text of an issue. +// This reads full content, which is safe on the home lab Gitea instance. +func (c *Client) GetIssueBody(owner, repo string, issue int64) (string, error) { + iss, _, err := c.api.GetIssue(owner, repo, issue) + if err != nil { + return "", log.E("gitea.GetIssueBody", "failed to get issue body", err) + } + + return iss.Body, nil +} diff --git a/gitea/repos.go b/gitea/repos.go new file mode 100644 index 0000000..e7380c3 --- /dev/null +++ b/gitea/repos.go @@ -0,0 +1,110 @@ +package gitea + +import ( + "code.gitea.io/sdk/gitea" + + "forge.lthn.ai/core/go/pkg/log" +) + +// ListOrgRepos returns all repositories for the given organisation. +func (c *Client) ListOrgRepos(org string) ([]*gitea.Repository, error) { + var all []*gitea.Repository + page := 1 + + for { + repos, resp, err := c.api.ListOrgRepos(org, gitea.ListOrgReposOptions{ + ListOptions: gitea.ListOptions{Page: page, PageSize: 50}, + }) + if err != nil { + return nil, log.E("gitea.ListOrgRepos", "failed to list org repos", err) + } + + all = append(all, repos...) + + if resp == nil || page >= resp.LastPage { + break + } + page++ + } + + return all, nil +} + +// ListUserRepos returns all repositories for the authenticated user. 
+func (c *Client) ListUserRepos() ([]*gitea.Repository, error) { + var all []*gitea.Repository + page := 1 + + for { + repos, resp, err := c.api.ListMyRepos(gitea.ListReposOptions{ + ListOptions: gitea.ListOptions{Page: page, PageSize: 50}, + }) + if err != nil { + return nil, log.E("gitea.ListUserRepos", "failed to list user repos", err) + } + + all = append(all, repos...) + + if resp == nil || page >= resp.LastPage { + break + } + page++ + } + + return all, nil +} + +// GetRepo returns a single repository by owner and name. +func (c *Client) GetRepo(owner, name string) (*gitea.Repository, error) { + repo, _, err := c.api.GetRepo(owner, name) + if err != nil { + return nil, log.E("gitea.GetRepo", "failed to get repo", err) + } + + return repo, nil +} + +// CreateMirror creates a mirror repository on Gitea from a GitHub clone URL. +// This uses the Gitea migration API to set up a pull mirror. +// If authToken is provided, it is used to authenticate against the source (e.g. for private GitHub repos). +func (c *Client) CreateMirror(owner, name, cloneURL, authToken string) (*gitea.Repository, error) { + opts := gitea.MigrateRepoOption{ + RepoName: name, + RepoOwner: owner, + CloneAddr: cloneURL, + Service: gitea.GitServiceGithub, + Mirror: true, + Description: "Mirror of " + cloneURL, + } + + if authToken != "" { + opts.AuthToken = authToken + } + + repo, _, err := c.api.MigrateRepo(opts) + if err != nil { + return nil, log.E("gitea.CreateMirror", "failed to create mirror", err) + } + + return repo, nil +} + +// DeleteRepo deletes a repository from Gitea. +func (c *Client) DeleteRepo(owner, name string) error { + _, err := c.api.DeleteRepo(owner, name) + if err != nil { + return log.E("gitea.DeleteRepo", "failed to delete repo", err) + } + + return nil +} + +// CreateOrgRepo creates a new empty repository under an organisation. 
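+//
+// A sketch with illustrative values:
+//
+//	repo, err := client.CreateOrgRepo("lthn", gitea.CreateRepoOption{
+//		Name:     "scratch",
+//		Private:  true,
+//		AutoInit: true,
+//	})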
+func (c *Client) CreateOrgRepo(org string, opts gitea.CreateRepoOption) (*gitea.Repository, error) { + repo, _, err := c.api.CreateOrgRepo(org, opts) + if err != nil { + return nil, log.E("gitea.CreateOrgRepo", "failed to create org repo", err) + } + + return repo, nil +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..8e85ad5 --- /dev/null +++ b/go.mod @@ -0,0 +1,36 @@ +module forge.lthn.ai/core/go-scm + +go 1.25.5 + +require ( + code.gitea.io/sdk/gitea v0.23.2 + codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2 v2.2.0 + forge.lthn.ai/core/go v0.0.0 + github.com/stretchr/testify v1.11.1 + golang.org/x/net v0.50.0 +) + +require ( + github.com/42wim/httpsig v1.2.3 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/davidmz/go-pageant v1.0.2 // indirect + github.com/fsnotify/fsnotify v1.9.0 // indirect + github.com/go-fed/httpsig v1.1.0 // indirect + github.com/go-viper/mapstructure/v2 v2.5.0 // indirect + github.com/hashicorp/go-version v1.8.0 // indirect + github.com/pelletier/go-toml/v2 v2.2.4 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/sagikazarmark/locafero v0.12.0 // indirect + github.com/spf13/afero v1.15.0 // indirect + github.com/spf13/cast v1.10.0 // indirect + github.com/spf13/pflag v1.0.10 // indirect + github.com/spf13/viper v1.21.0 // indirect + github.com/subosito/gotenv v1.6.0 // indirect + go.yaml.in/yaml/v3 v3.0.4 // indirect + golang.org/x/crypto v0.48.0 // indirect + golang.org/x/sys v0.41.0 // indirect + golang.org/x/text v0.34.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) + +replace forge.lthn.ai/core/go => ../go diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..5c132a1 --- /dev/null +++ b/go.sum @@ -0,0 +1,74 @@ +code.gitea.io/sdk/gitea v0.23.2 h1:iJB1FDmLegwfwjX8gotBDHdPSbk/ZR8V9VmEJaVsJYg= +code.gitea.io/sdk/gitea v0.23.2/go.mod h1:yyF5+GhljqvA30sRDreoyHILruNiy4ASufugzYg0VHM= +codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2 v2.2.0 h1:HTCWpzyWQOHDWt3LzI6/d2jvUDsw/vgGRWm/8BTvcqI= +codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2 v2.2.0/go.mod h1:ZglEEDj+qkxYUb+SQIeqGtFxQrbaMYqIOgahNKb7uxs= +github.com/42wim/httpsig v1.2.3 h1:xb0YyWhkYj57SPtfSttIobJUPJZB9as1nsfo7KWVcEs= +github.com/42wim/httpsig v1.2.3/go.mod h1:nZq9OlYKDrUBhptd77IHx4/sZZD+IxTBADvAPI9G/EM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davidmz/go-pageant v1.0.2 h1:bPblRCh5jGU+Uptpz6LgMZGD5hJoOt7otgT454WvHn0= +github.com/davidmz/go-pageant v1.0.2/go.mod h1:P2EDDnMqIwG5Rrp05dTRITj9z2zpGcD9efWSkTNKLIE= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= +github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/go-fed/httpsig v1.1.0 h1:9M+hb0jkEICD8/cAiNqEB66R87tTINszBRTjwjQzWcI= +github.com/go-fed/httpsig v1.1.0/go.mod h1:RCMrTZvN1bJYtofsG4rd5NaO5obxQ5xBkdiS7xsT7bM= +github.com/go-viper/mapstructure/v2 v2.5.0 h1:vM5IJoUAy3d7zRSVtIwQgBj7BiWtMPfmPEgAXnvj1Ro= +github.com/go-viper/mapstructure/v2 v2.5.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= +github.com/google/go-cmp v0.6.0 
h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/hashicorp/go-version v1.8.0 h1:KAkNb1HAiZd1ukkxDFGmokVZe1Xy9HG6NUp+bPle2i4= +github.com/hashicorp/go-version v1.8.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= +github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/sagikazarmark/locafero v0.12.0 h1:/NQhBAkUb4+fH1jivKHWusDYFjMOOKU88eegjfxfHb4= +github.com/sagikazarmark/locafero v0.12.0/go.mod h1:sZh36u/YSZ918v0Io+U9ogLYQJ9tLLBmM4eneO6WwsI= +github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= +github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg= +github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= +github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= +github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU= +github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= +golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= +golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60= +golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod 
h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= +golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= +golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/jobrunner/forgejo/signals.go b/jobrunner/forgejo/signals.go new file mode 100644 index 0000000..9f1e1ee --- /dev/null +++ b/jobrunner/forgejo/signals.go @@ -0,0 +1,114 @@ +package forgejo + +import ( + "regexp" + "strconv" + + forgejosdk "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2" + + "forge.lthn.ai/core/go-scm/jobrunner" +) + +// epicChildRe matches checklist items: - [ ] #42 or - [x] #42 +var epicChildRe = regexp.MustCompile(`- \[([ x])\] #(\d+)`) + +// parseEpicChildren extracts child issue numbers from an epic body's checklist. +func parseEpicChildren(body string) (unchecked []int, checked []int) { + matches := epicChildRe.FindAllStringSubmatch(body, -1) + for _, m := range matches { + num, err := strconv.Atoi(m[2]) + if err != nil { + continue + } + if m[1] == "x" { + checked = append(checked, num) + } else { + unchecked = append(unchecked, num) + } + } + return unchecked, checked +} + +// linkedPRRe matches "#N" references in PR bodies. +var linkedPRRe = regexp.MustCompile(`#(\d+)`) + +// findLinkedPR finds the first PR whose body references the given issue number. +func findLinkedPR(prs []*forgejosdk.PullRequest, issueNumber int) *forgejosdk.PullRequest { + target := strconv.Itoa(issueNumber) + for _, pr := range prs { + matches := linkedPRRe.FindAllStringSubmatch(pr.Body, -1) + for _, m := range matches { + if m[1] == target { + return pr + } + } + } + return nil +} + +// mapPRState maps Forgejo's PR state and merged flag to a canonical string. +func mapPRState(pr *forgejosdk.PullRequest) string { + if pr.HasMerged { + return "MERGED" + } + switch pr.State { + case forgejosdk.StateOpen: + return "OPEN" + case forgejosdk.StateClosed: + return "CLOSED" + default: + return "CLOSED" + } +} + +// mapMergeable maps Forgejo's boolean Mergeable field to a canonical string. +func mapMergeable(pr *forgejosdk.PullRequest) string { + if pr.HasMerged { + return "UNKNOWN" + } + if pr.Mergeable { + return "MERGEABLE" + } + return "CONFLICTING" +} + +// mapCombinedStatus maps a Forgejo CombinedStatus to SUCCESS/FAILURE/PENDING. 
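+// Nil or zero-count statuses read as PENDING (no checks reported yet),
+// success reads as SUCCESS, failure/error read as FAILURE, and any other
+// state stays PENDING.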
+func mapCombinedStatus(cs *forgejosdk.CombinedStatus) string { + if cs == nil || cs.TotalCount == 0 { + return "PENDING" + } + switch cs.State { + case forgejosdk.StatusSuccess: + return "SUCCESS" + case forgejosdk.StatusFailure, forgejosdk.StatusError: + return "FAILURE" + default: + return "PENDING" + } +} + +// buildSignal creates a PipelineSignal from Forgejo API data. +func buildSignal( + owner, repo string, + epicNumber, childNumber int, + pr *forgejosdk.PullRequest, + checkStatus string, +) *jobrunner.PipelineSignal { + sig := &jobrunner.PipelineSignal{ + EpicNumber: epicNumber, + ChildNumber: childNumber, + PRNumber: int(pr.Index), + RepoOwner: owner, + RepoName: repo, + PRState: mapPRState(pr), + IsDraft: false, // SDK v2.2.0 doesn't expose Draft; treat as non-draft + Mergeable: mapMergeable(pr), + CheckStatus: checkStatus, + } + + if pr.Head != nil { + sig.LastCommitSHA = pr.Head.Sha + } + + return sig +} diff --git a/jobrunner/forgejo/source.go b/jobrunner/forgejo/source.go new file mode 100644 index 0000000..92b2ba2 --- /dev/null +++ b/jobrunner/forgejo/source.go @@ -0,0 +1,173 @@ +package forgejo + +import ( + "context" + "fmt" + "strings" + + "forge.lthn.ai/core/go-scm/forge" + "forge.lthn.ai/core/go-scm/jobrunner" + "forge.lthn.ai/core/go/pkg/log" +) + +// Config configures a ForgejoSource. +type Config struct { + Repos []string // "owner/repo" format +} + +// ForgejoSource polls a Forgejo instance for pipeline signals from epic issues. +type ForgejoSource struct { + repos []string + forge *forge.Client +} + +// New creates a ForgejoSource using the given forge client. +func New(cfg Config, client *forge.Client) *ForgejoSource { + return &ForgejoSource{ + repos: cfg.Repos, + forge: client, + } +} + +// Name returns the source identifier. +func (s *ForgejoSource) Name() string { + return "forgejo" +} + +// Poll fetches epics and their linked PRs from all configured repositories, +// returning a PipelineSignal for each unchecked child that has a linked PR. +func (s *ForgejoSource) Poll(ctx context.Context) ([]*jobrunner.PipelineSignal, error) { + var signals []*jobrunner.PipelineSignal + + for _, repoFull := range s.repos { + owner, repo, err := splitRepo(repoFull) + if err != nil { + log.Error("invalid repo format", "repo", repoFull, "err", err) + continue + } + + repoSignals, err := s.pollRepo(ctx, owner, repo) + if err != nil { + log.Error("poll repo failed", "repo", repoFull, "err", err) + continue + } + + signals = append(signals, repoSignals...) + } + + return signals, nil +} + +// Report posts the action result as a comment on the epic issue. +func (s *ForgejoSource) Report(ctx context.Context, result *jobrunner.ActionResult) error { + if result == nil { + return nil + } + + status := "succeeded" + if !result.Success { + status = "failed" + } + + body := fmt.Sprintf("**jobrunner** `%s` %s for #%d (PR #%d)", result.Action, status, result.ChildNumber, result.PRNumber) + if result.Error != "" { + body += fmt.Sprintf("\n\n```\n%s\n```", result.Error) + } + + return s.forge.CreateIssueComment(result.RepoOwner, result.RepoName, int64(result.EpicNumber), body) +} + +// pollRepo fetches epics and PRs for a single repository. +func (s *ForgejoSource) pollRepo(_ context.Context, owner, repo string) ([]*jobrunner.PipelineSignal, error) { + // Fetch epic issues (label=epic, state=open). + issues, err := s.forge.ListIssues(owner, repo, forge.ListIssuesOpts{State: "open"}) + if err != nil { + return nil, log.E("forgejo.pollRepo", "fetch issues", err) + } + + // Filter to epics only. 
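+ // Note: forge.ListIssuesOpts exposes a Labels filter, but the epic label
+ // is matched client-side here, so the fetch above returns non-epics too.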
+ var epics []epicInfo + for _, issue := range issues { + for _, label := range issue.Labels { + if label.Name == "epic" { + epics = append(epics, epicInfo{ + Number: int(issue.Index), + Body: issue.Body, + }) + break + } + } + } + + if len(epics) == 0 { + return nil, nil + } + + // Fetch all open PRs (and also merged/closed to catch MERGED state). + prs, err := s.forge.ListPullRequests(owner, repo, "all") + if err != nil { + return nil, log.E("forgejo.pollRepo", "fetch PRs", err) + } + + var signals []*jobrunner.PipelineSignal + + for _, epic := range epics { + unchecked, _ := parseEpicChildren(epic.Body) + for _, childNum := range unchecked { + pr := findLinkedPR(prs, childNum) + + if pr == nil { + // No PR yet — check if the child issue is assigned (needs coding). + childIssue, err := s.forge.GetIssue(owner, repo, int64(childNum)) + if err != nil { + log.Error("fetch child issue failed", "repo", owner+"/"+repo, "issue", childNum, "err", err) + continue + } + if len(childIssue.Assignees) > 0 && childIssue.Assignees[0].UserName != "" { + sig := &jobrunner.PipelineSignal{ + EpicNumber: epic.Number, + ChildNumber: childNum, + RepoOwner: owner, + RepoName: repo, + NeedsCoding: true, + Assignee: childIssue.Assignees[0].UserName, + IssueTitle: childIssue.Title, + IssueBody: childIssue.Body, + } + signals = append(signals, sig) + } + continue + } + + // Get combined commit status for the PR's head SHA. + checkStatus := "PENDING" + if pr.Head != nil && pr.Head.Sha != "" { + cs, err := s.forge.GetCombinedStatus(owner, repo, pr.Head.Sha) + if err != nil { + log.Error("fetch combined status failed", "repo", owner+"/"+repo, "sha", pr.Head.Sha, "err", err) + } else { + checkStatus = mapCombinedStatus(cs) + } + } + + sig := buildSignal(owner, repo, epic.Number, childNum, pr, checkStatus) + signals = append(signals, sig) + } + } + + return signals, nil +} + +type epicInfo struct { + Number int + Body string +} + +// splitRepo parses "owner/repo" into its components. +func splitRepo(full string) (string, string, error) { + parts := strings.SplitN(full, "/", 2) + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + return "", "", log.E("forgejo.splitRepo", fmt.Sprintf("expected owner/repo format, got %q", full), nil) + } + return parts[0], parts[1], nil +} diff --git a/jobrunner/forgejo/source_test.go b/jobrunner/forgejo/source_test.go new file mode 100644 index 0000000..5721201 --- /dev/null +++ b/jobrunner/forgejo/source_test.go @@ -0,0 +1,177 @@ +package forgejo + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "forge.lthn.ai/core/go-scm/forge" + "forge.lthn.ai/core/go-scm/jobrunner" +) + +// withVersion wraps an HTTP handler to serve the Forgejo /api/v1/version +// endpoint that the SDK calls during NewClient initialization. 
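+// Without it, forge.New fails during client construction and the handler
+// under test is never reached.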
+func withVersion(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.HasSuffix(r.URL.Path, "/version") { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"version":"9.0.0"}`)) + return + } + next.ServeHTTP(w, r) + }) +} + +func newTestClient(t *testing.T, url string) *forge.Client { + t.Helper() + client, err := forge.New(url, "test-token") + require.NoError(t, err) + return client +} + +func TestForgejoSource_Name(t *testing.T) { + s := New(Config{}, nil) + assert.Equal(t, "forgejo", s.Name()) +} + +func TestForgejoSource_Poll_Good(t *testing.T) { + srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + path := r.URL.Path + w.Header().Set("Content-Type", "application/json") + + switch { + // List issues — return one epic + case strings.Contains(path, "/issues"): + issues := []map[string]any{ + { + "number": 10, + "body": "## Tasks\n- [ ] #11\n- [x] #12\n", + "labels": []map[string]string{{"name": "epic"}}, + "state": "open", + }, + } + _ = json.NewEncoder(w).Encode(issues) + + // List PRs — return one open PR linked to #11 + case strings.Contains(path, "/pulls"): + prs := []map[string]any{ + { + "number": 20, + "body": "Fixes #11", + "state": "open", + "mergeable": true, + "merged": false, + "head": map[string]string{"sha": "abc123", "ref": "feature", "label": "feature"}, + }, + } + _ = json.NewEncoder(w).Encode(prs) + + // Combined status + case strings.Contains(path, "/status"): + status := map[string]any{ + "state": "success", + "total_count": 1, + "statuses": []map[string]any{{"status": "success", "context": "ci"}}, + } + _ = json.NewEncoder(w).Encode(status) + + default: + w.WriteHeader(http.StatusNotFound) + } + }))) + defer srv.Close() + + client := newTestClient(t, srv.URL) + s := New(Config{Repos: []string{"test-org/test-repo"}}, client) + + signals, err := s.Poll(context.Background()) + require.NoError(t, err) + + require.Len(t, signals, 1) + sig := signals[0] + assert.Equal(t, 10, sig.EpicNumber) + assert.Equal(t, 11, sig.ChildNumber) + assert.Equal(t, 20, sig.PRNumber) + assert.Equal(t, "OPEN", sig.PRState) + assert.Equal(t, "MERGEABLE", sig.Mergeable) + assert.Equal(t, "SUCCESS", sig.CheckStatus) + assert.Equal(t, "test-org", sig.RepoOwner) + assert.Equal(t, "test-repo", sig.RepoName) + assert.Equal(t, "abc123", sig.LastCommitSHA) +} + +func TestForgejoSource_Poll_NoEpics(t *testing.T) { + srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode([]any{}) + }))) + defer srv.Close() + + client := newTestClient(t, srv.URL) + s := New(Config{Repos: []string{"test-org/test-repo"}}, client) + + signals, err := s.Poll(context.Background()) + require.NoError(t, err) + assert.Empty(t, signals) +} + +func TestForgejoSource_Report_Good(t *testing.T) { + var capturedBody string + + srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + var body map[string]string + _ = json.NewDecoder(r.Body).Decode(&body) + capturedBody = body["body"] + _ = json.NewEncoder(w).Encode(map[string]any{"id": 1}) + }))) + defer srv.Close() + + client := newTestClient(t, srv.URL) + s := New(Config{}, client) + + result := &jobrunner.ActionResult{ + Action: "enable_auto_merge", + RepoOwner: "test-org", + RepoName: "test-repo", + 
EpicNumber: 10, + ChildNumber: 11, + PRNumber: 20, + Success: true, + } + + err := s.Report(context.Background(), result) + require.NoError(t, err) + assert.Contains(t, capturedBody, "enable_auto_merge") + assert.Contains(t, capturedBody, "succeeded") +} + +func TestParseEpicChildren(t *testing.T) { + body := "## Tasks\n- [x] #1\n- [ ] #7\n- [ ] #8\n- [x] #3\n" + unchecked, checked := parseEpicChildren(body) + assert.Equal(t, []int{7, 8}, unchecked) + assert.Equal(t, []int{1, 3}, checked) +} + +func TestFindLinkedPR(t *testing.T) { + assert.Nil(t, findLinkedPR(nil, 7)) +} + +func TestSplitRepo(t *testing.T) { + owner, repo, err := splitRepo("host-uk/core") + require.NoError(t, err) + assert.Equal(t, "host-uk", owner) + assert.Equal(t, "core", repo) + + _, _, err = splitRepo("invalid") + assert.Error(t, err) + + _, _, err = splitRepo("") + assert.Error(t, err) +} diff --git a/jobrunner/handlers/completion.go b/jobrunner/handlers/completion.go new file mode 100644 index 0000000..0355bda --- /dev/null +++ b/jobrunner/handlers/completion.go @@ -0,0 +1,87 @@ +package handlers + +import ( + "context" + "fmt" + "time" + + "forge.lthn.ai/core/go-scm/forge" + "forge.lthn.ai/core/go-scm/jobrunner" +) + +const ( + ColorAgentComplete = "#0e8a16" // Green +) + +// CompletionHandler manages issue state when an agent finishes work. +type CompletionHandler struct { + forge *forge.Client +} + +// NewCompletionHandler creates a handler for agent completion events. +func NewCompletionHandler(client *forge.Client) *CompletionHandler { + return &CompletionHandler{ + forge: client, + } +} + +// Name returns the handler identifier. +func (h *CompletionHandler) Name() string { + return "completion" +} + +// Match returns true if the signal indicates an agent has finished a task. +func (h *CompletionHandler) Match(signal *jobrunner.PipelineSignal) bool { + return signal.Type == "agent_completion" +} + +// Execute updates the issue labels based on the completion status. +func (h *CompletionHandler) Execute(ctx context.Context, signal *jobrunner.PipelineSignal) (*jobrunner.ActionResult, error) { + start := time.Now() + + // Remove in-progress label. + if inProgressLabel, err := h.forge.GetLabelByName(signal.RepoOwner, signal.RepoName, LabelInProgress); err == nil { + _ = h.forge.RemoveIssueLabel(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber), inProgressLabel.ID) + } + + if signal.Success { + completeLabel, err := h.forge.EnsureLabel(signal.RepoOwner, signal.RepoName, LabelAgentComplete, ColorAgentComplete) + if err != nil { + return nil, fmt.Errorf("ensure label %s: %w", LabelAgentComplete, err) + } + + if err := h.forge.AddIssueLabels(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber), []int64{completeLabel.ID}); err != nil { + return nil, fmt.Errorf("add completed label: %w", err) + } + + if signal.Message != "" { + _ = h.forge.CreateIssueComment(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber), signal.Message) + } + } else { + failedLabel, err := h.forge.EnsureLabel(signal.RepoOwner, signal.RepoName, LabelAgentFailed, ColorAgentFailed) + if err != nil { + return nil, fmt.Errorf("ensure label %s: %w", LabelAgentFailed, err) + } + + if err := h.forge.AddIssueLabels(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber), []int64{failedLabel.ID}); err != nil { + return nil, fmt.Errorf("add failed label: %w", err) + } + + msg := "Agent reported failure." 
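+		// Surface the agent-reported error verbatim so a human can triage.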
+ if signal.Error != "" { + msg += fmt.Sprintf("\n\nError: %s", signal.Error) + } + _ = h.forge.CreateIssueComment(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber), msg) + } + + return &jobrunner.ActionResult{ + Action: "completion", + RepoOwner: signal.RepoOwner, + RepoName: signal.RepoName, + EpicNumber: signal.EpicNumber, + ChildNumber: signal.ChildNumber, + Success: true, + Timestamp: time.Now(), + Duration: time.Since(start), + }, nil +} diff --git a/jobrunner/handlers/dispatch.go b/jobrunner/handlers/dispatch.go new file mode 100644 index 0000000..845b242 --- /dev/null +++ b/jobrunner/handlers/dispatch.go @@ -0,0 +1,290 @@ +package handlers + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "path/filepath" + "time" + + "forge.lthn.ai/core/go-scm/agentci" + "forge.lthn.ai/core/go-scm/forge" + "forge.lthn.ai/core/go-scm/jobrunner" + "forge.lthn.ai/core/go/pkg/log" +) + +const ( + LabelAgentReady = "agent-ready" + LabelInProgress = "in-progress" + LabelAgentFailed = "agent-failed" + LabelAgentComplete = "agent-completed" + + ColorInProgress = "#1d76db" // Blue + ColorAgentFailed = "#c0392b" // Red +) + +// DispatchTicket is the JSON payload written to the agent's queue. +// The ForgeToken is transferred separately via a .env file with 0600 permissions. +type DispatchTicket struct { + ID string `json:"id"` + RepoOwner string `json:"repo_owner"` + RepoName string `json:"repo_name"` + IssueNumber int `json:"issue_number"` + IssueTitle string `json:"issue_title"` + IssueBody string `json:"issue_body"` + TargetBranch string `json:"target_branch"` + EpicNumber int `json:"epic_number"` + ForgeURL string `json:"forge_url"` + ForgeUser string `json:"forgejo_user"` + Model string `json:"model,omitempty"` + Runner string `json:"runner,omitempty"` + VerifyModel string `json:"verify_model,omitempty"` + DualRun bool `json:"dual_run"` + CreatedAt string `json:"created_at"` +} + +// DispatchHandler dispatches coding work to remote agent machines via SSH. +type DispatchHandler struct { + forge *forge.Client + forgeURL string + token string + spinner *agentci.Spinner +} + +// NewDispatchHandler creates a handler that dispatches tickets to agent machines. +func NewDispatchHandler(client *forge.Client, forgeURL, token string, spinner *agentci.Spinner) *DispatchHandler { + return &DispatchHandler{ + forge: client, + forgeURL: forgeURL, + token: token, + spinner: spinner, + } +} + +// Name returns the handler identifier. +func (h *DispatchHandler) Name() string { + return "dispatch" +} + +// Match returns true for signals where a child issue needs coding (no PR yet) +// and the assignee is a known agent (by config key or Forgejo username). +func (h *DispatchHandler) Match(signal *jobrunner.PipelineSignal) bool { + if !signal.NeedsCoding { + return false + } + _, _, ok := h.spinner.FindByForgejoUser(signal.Assignee) + return ok +} + +// Execute creates a ticket JSON and transfers it securely to the agent's queue directory. +func (h *DispatchHandler) Execute(ctx context.Context, signal *jobrunner.PipelineSignal) (*jobrunner.ActionResult, error) { + start := time.Now() + + agentName, agent, ok := h.spinner.FindByForgejoUser(signal.Assignee) + if !ok { + return nil, fmt.Errorf("unknown agent: %s", signal.Assignee) + } + + // Sanitize inputs to prevent path traversal. 
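+	// e.g. an owner or repo of "../../etc" (hypothetical) must be rejected here,
+	// before it can reach a queue path or a remote shell command.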
+ safeOwner, err := agentci.SanitizePath(signal.RepoOwner) + if err != nil { + return nil, fmt.Errorf("invalid repo owner: %w", err) + } + safeRepo, err := agentci.SanitizePath(signal.RepoName) + if err != nil { + return nil, fmt.Errorf("invalid repo name: %w", err) + } + + // Ensure in-progress label exists on repo. + inProgressLabel, err := h.forge.EnsureLabel(safeOwner, safeRepo, LabelInProgress, ColorInProgress) + if err != nil { + return nil, fmt.Errorf("ensure label %s: %w", LabelInProgress, err) + } + + // Check if already in progress to prevent double-dispatch. + issue, err := h.forge.GetIssue(safeOwner, safeRepo, int64(signal.ChildNumber)) + if err == nil { + for _, l := range issue.Labels { + if l.Name == LabelInProgress || l.Name == LabelAgentComplete { + log.Info("issue already processed, skipping", "issue", signal.ChildNumber, "label", l.Name) + return &jobrunner.ActionResult{ + Action: "dispatch", + Success: true, + Timestamp: time.Now(), + Duration: time.Since(start), + }, nil + } + } + } + + // Assign agent and add in-progress label. + if err := h.forge.AssignIssue(safeOwner, safeRepo, int64(signal.ChildNumber), []string{signal.Assignee}); err != nil { + log.Warn("failed to assign agent, continuing", "err", err) + } + + if err := h.forge.AddIssueLabels(safeOwner, safeRepo, int64(signal.ChildNumber), []int64{inProgressLabel.ID}); err != nil { + return nil, fmt.Errorf("add in-progress label: %w", err) + } + + // Remove agent-ready label if present. + if readyLabel, err := h.forge.GetLabelByName(safeOwner, safeRepo, LabelAgentReady); err == nil { + _ = h.forge.RemoveIssueLabel(safeOwner, safeRepo, int64(signal.ChildNumber), readyLabel.ID) + } + + // Clotho planning — determine execution mode. + runMode := h.spinner.DeterminePlan(signal, agentName) + verifyModel := "" + if runMode == agentci.ModeDual { + verifyModel = h.spinner.GetVerifierModel(agentName) + } + + // Build ticket. + targetBranch := "new" // TODO: resolve from epic or repo default + ticketID := fmt.Sprintf("%s-%s-%d-%d", safeOwner, safeRepo, signal.ChildNumber, time.Now().Unix()) + + ticket := DispatchTicket{ + ID: ticketID, + RepoOwner: safeOwner, + RepoName: safeRepo, + IssueNumber: signal.ChildNumber, + IssueTitle: signal.IssueTitle, + IssueBody: signal.IssueBody, + TargetBranch: targetBranch, + EpicNumber: signal.EpicNumber, + ForgeURL: h.forgeURL, + ForgeUser: signal.Assignee, + Model: agent.Model, + Runner: agent.Runner, + VerifyModel: verifyModel, + DualRun: runMode == agentci.ModeDual, + CreatedAt: time.Now().UTC().Format(time.RFC3339), + } + + ticketJSON, err := json.MarshalIndent(ticket, "", " ") + if err != nil { + h.failDispatch(signal, "Failed to marshal ticket JSON") + return nil, fmt.Errorf("marshal ticket: %w", err) + } + + // Check if ticket already exists on agent (dedup). + ticketName := fmt.Sprintf("ticket-%s-%s-%d.json", safeOwner, safeRepo, signal.ChildNumber) + if h.ticketExists(ctx, agent, ticketName) { + log.Info("ticket already queued, skipping", "ticket", ticketName, "agent", signal.Assignee) + return &jobrunner.ActionResult{ + Action: "dispatch", + RepoOwner: safeOwner, + RepoName: safeRepo, + EpicNumber: signal.EpicNumber, + ChildNumber: signal.ChildNumber, + Success: true, + Timestamp: time.Now(), + Duration: time.Since(start), + }, nil + } + + // Transfer ticket JSON. 
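+	// Mode 0644 is fine for the ticket itself: it contains no credentials. The
+	// forge token follows separately below, locked to 0600.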
+ remoteTicketPath := filepath.Join(agent.QueueDir, ticketName) + if err := h.secureTransfer(ctx, agent, remoteTicketPath, ticketJSON, 0644); err != nil { + h.failDispatch(signal, fmt.Sprintf("Ticket transfer failed: %v", err)) + return &jobrunner.ActionResult{ + Action: "dispatch", + RepoOwner: safeOwner, + RepoName: safeRepo, + EpicNumber: signal.EpicNumber, + ChildNumber: signal.ChildNumber, + Success: false, + Error: fmt.Sprintf("transfer ticket: %v", err), + Timestamp: time.Now(), + Duration: time.Since(start), + }, nil + } + + // Transfer token via separate .env file with 0600 permissions. + envContent := fmt.Sprintf("FORGE_TOKEN=%s\n", h.token) + remoteEnvPath := filepath.Join(agent.QueueDir, fmt.Sprintf(".env.%s", ticketID)) + if err := h.secureTransfer(ctx, agent, remoteEnvPath, []byte(envContent), 0600); err != nil { + // Clean up the ticket if env transfer fails. + _ = h.runRemote(ctx, agent, fmt.Sprintf("rm -f %s", agentci.EscapeShellArg(remoteTicketPath))) + h.failDispatch(signal, fmt.Sprintf("Token transfer failed: %v", err)) + return &jobrunner.ActionResult{ + Action: "dispatch", + RepoOwner: safeOwner, + RepoName: safeRepo, + EpicNumber: signal.EpicNumber, + ChildNumber: signal.ChildNumber, + Success: false, + Error: fmt.Sprintf("transfer token: %v", err), + Timestamp: time.Now(), + Duration: time.Since(start), + }, nil + } + + // Comment on issue. + modeStr := "Standard" + if runMode == agentci.ModeDual { + modeStr = "Clotho Verified (Dual Run)" + } + comment := fmt.Sprintf("Dispatched to **%s** agent queue.\nMode: **%s**", signal.Assignee, modeStr) + _ = h.forge.CreateIssueComment(safeOwner, safeRepo, int64(signal.ChildNumber), comment) + + return &jobrunner.ActionResult{ + Action: "dispatch", + RepoOwner: safeOwner, + RepoName: safeRepo, + EpicNumber: signal.EpicNumber, + ChildNumber: signal.ChildNumber, + Success: true, + Timestamp: time.Now(), + Duration: time.Since(start), + }, nil +} + +// failDispatch handles cleanup when dispatch fails (adds failed label, removes in-progress). +func (h *DispatchHandler) failDispatch(signal *jobrunner.PipelineSignal, reason string) { + if failedLabel, err := h.forge.EnsureLabel(signal.RepoOwner, signal.RepoName, LabelAgentFailed, ColorAgentFailed); err == nil { + _ = h.forge.AddIssueLabels(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber), []int64{failedLabel.ID}) + } + + if inProgressLabel, err := h.forge.GetLabelByName(signal.RepoOwner, signal.RepoName, LabelInProgress); err == nil { + _ = h.forge.RemoveIssueLabel(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber), inProgressLabel.ID) + } + + _ = h.forge.CreateIssueComment(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber), fmt.Sprintf("Agent dispatch failed: %s", reason)) +} + +// secureTransfer writes data to a remote path via SSH stdin, preventing command injection. +func (h *DispatchHandler) secureTransfer(ctx context.Context, agent agentci.AgentConfig, remotePath string, data []byte, mode int) error { + safeRemotePath := agentci.EscapeShellArg(remotePath) + remoteCmd := fmt.Sprintf("cat > %s && chmod %o %s", safeRemotePath, mode, safeRemotePath) + + cmd := agentci.SecureSSHCommand(agent.Host, remoteCmd) + cmd.Stdin = bytes.NewReader(data) + + output, err := cmd.CombinedOutput() + if err != nil { + return log.E("dispatch.transfer", fmt.Sprintf("ssh to %s failed: %s", agent.Host, string(output)), err) + } + return nil +} + +// runRemote executes a command on the agent via SSH. 
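+// The command string is executed as-is on the remote host, so callers must
+// escape any interpolated paths with agentci.EscapeShellArg first.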
+func (h *DispatchHandler) runRemote(ctx context.Context, agent agentci.AgentConfig, cmdStr string) error { + cmd := agentci.SecureSSHCommand(agent.Host, cmdStr) + return cmd.Run() +} + +// ticketExists checks if a ticket file already exists in queue, active, or done. +func (h *DispatchHandler) ticketExists(ctx context.Context, agent agentci.AgentConfig, ticketName string) bool { + safeTicket, err := agentci.SanitizePath(ticketName) + if err != nil { + return false + } + qDir := agent.QueueDir + checkCmd := fmt.Sprintf( + "test -f %s/%s || test -f %s/../active/%s || test -f %s/../done/%s", + qDir, safeTicket, qDir, safeTicket, qDir, safeTicket, + ) + cmd := agentci.SecureSSHCommand(agent.Host, checkCmd) + return cmd.Run() == nil +} diff --git a/jobrunner/handlers/dispatch_test.go b/jobrunner/handlers/dispatch_test.go new file mode 100644 index 0000000..8456d38 --- /dev/null +++ b/jobrunner/handlers/dispatch_test.go @@ -0,0 +1,327 @@ +package handlers + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "forge.lthn.ai/core/go-scm/agentci" + "forge.lthn.ai/core/go-scm/jobrunner" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// newTestSpinner creates a Spinner with the given agents for testing. +func newTestSpinner(agents map[string]agentci.AgentConfig) *agentci.Spinner { + return agentci.NewSpinner(agentci.ClothoConfig{Strategy: "direct"}, agents) +} + +// --- Match tests --- + +func TestDispatch_Match_Good_NeedsCoding(t *testing.T) { + spinner := newTestSpinner(map[string]agentci.AgentConfig{ + "darbs-claude": {Host: "claude@192.168.0.201", QueueDir: "~/ai-work/queue", Active: true}, + }) + h := NewDispatchHandler(nil, "", "", spinner) + sig := &jobrunner.PipelineSignal{ + NeedsCoding: true, + Assignee: "darbs-claude", + } + assert.True(t, h.Match(sig)) +} + +func TestDispatch_Match_Good_MultipleAgents(t *testing.T) { + spinner := newTestSpinner(map[string]agentci.AgentConfig{ + "darbs-claude": {Host: "claude@192.168.0.201", QueueDir: "~/ai-work/queue", Active: true}, + "local-codex": {Host: "localhost", QueueDir: "~/ai-work/queue", Active: true}, + }) + h := NewDispatchHandler(nil, "", "", spinner) + sig := &jobrunner.PipelineSignal{ + NeedsCoding: true, + Assignee: "local-codex", + } + assert.True(t, h.Match(sig)) +} + +func TestDispatch_Match_Bad_HasPR(t *testing.T) { + spinner := newTestSpinner(map[string]agentci.AgentConfig{ + "darbs-claude": {Host: "claude@192.168.0.201", QueueDir: "~/ai-work/queue", Active: true}, + }) + h := NewDispatchHandler(nil, "", "", spinner) + sig := &jobrunner.PipelineSignal{ + NeedsCoding: false, + PRNumber: 7, + Assignee: "darbs-claude", + } + assert.False(t, h.Match(sig)) +} + +func TestDispatch_Match_Bad_UnknownAgent(t *testing.T) { + spinner := newTestSpinner(map[string]agentci.AgentConfig{ + "darbs-claude": {Host: "claude@192.168.0.201", QueueDir: "~/ai-work/queue", Active: true}, + }) + h := NewDispatchHandler(nil, "", "", spinner) + sig := &jobrunner.PipelineSignal{ + NeedsCoding: true, + Assignee: "unknown-user", + } + assert.False(t, h.Match(sig)) +} + +func TestDispatch_Match_Bad_NotAssigned(t *testing.T) { + spinner := newTestSpinner(map[string]agentci.AgentConfig{ + "darbs-claude": {Host: "claude@192.168.0.201", QueueDir: "~/ai-work/queue", Active: true}, + }) + h := NewDispatchHandler(nil, "", "", spinner) + sig := &jobrunner.PipelineSignal{ + NeedsCoding: true, + Assignee: "", + } + assert.False(t, h.Match(sig)) +} + +func TestDispatch_Match_Bad_EmptyAgentMap(t 
*testing.T) { + spinner := newTestSpinner(map[string]agentci.AgentConfig{}) + h := NewDispatchHandler(nil, "", "", spinner) + sig := &jobrunner.PipelineSignal{ + NeedsCoding: true, + Assignee: "darbs-claude", + } + assert.False(t, h.Match(sig)) +} + +// --- Name test --- + +func TestDispatch_Name_Good(t *testing.T) { + spinner := newTestSpinner(nil) + h := NewDispatchHandler(nil, "", "", spinner) + assert.Equal(t, "dispatch", h.Name()) +} + +// --- Execute tests --- + +func TestDispatch_Execute_Bad_UnknownAgent(t *testing.T) { + srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }))) + defer srv.Close() + + client := newTestForgeClient(t, srv.URL) + spinner := newTestSpinner(map[string]agentci.AgentConfig{ + "darbs-claude": {Host: "claude@192.168.0.201", QueueDir: "~/ai-work/queue", Active: true}, + }) + h := NewDispatchHandler(client, srv.URL, "test-token", spinner) + + sig := &jobrunner.PipelineSignal{ + NeedsCoding: true, + Assignee: "nonexistent-agent", + RepoOwner: "host-uk", + RepoName: "core", + ChildNumber: 1, + } + + _, err := h.Execute(context.Background(), sig) + require.Error(t, err) + assert.Contains(t, err.Error(), "unknown agent") +} + +func TestDispatch_TicketJSON_Good(t *testing.T) { + ticket := DispatchTicket{ + ID: "host-uk-core-5-1234567890", + RepoOwner: "host-uk", + RepoName: "core", + IssueNumber: 5, + IssueTitle: "Fix the thing", + IssueBody: "Please fix this bug", + TargetBranch: "new", + EpicNumber: 3, + ForgeURL: "https://forge.lthn.ai", + ForgeUser: "darbs-claude", + Model: "sonnet", + Runner: "claude", + DualRun: false, + CreatedAt: "2026-02-09T12:00:00Z", + } + + data, err := json.MarshalIndent(ticket, "", " ") + require.NoError(t, err) + + var decoded map[string]any + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + + assert.Equal(t, "host-uk-core-5-1234567890", decoded["id"]) + assert.Equal(t, "host-uk", decoded["repo_owner"]) + assert.Equal(t, "core", decoded["repo_name"]) + assert.Equal(t, float64(5), decoded["issue_number"]) + assert.Equal(t, "Fix the thing", decoded["issue_title"]) + assert.Equal(t, "Please fix this bug", decoded["issue_body"]) + assert.Equal(t, "new", decoded["target_branch"]) + assert.Equal(t, float64(3), decoded["epic_number"]) + assert.Equal(t, "https://forge.lthn.ai", decoded["forge_url"]) + assert.Equal(t, "darbs-claude", decoded["forgejo_user"]) + assert.Equal(t, "sonnet", decoded["model"]) + assert.Equal(t, "claude", decoded["runner"]) + // Token should NOT be present in the ticket. 
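+	// It travels out-of-band in the 0600 .env file written at dispatch time.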
+ _, hasToken := decoded["forge_token"] + assert.False(t, hasToken, "forge_token must not be in ticket JSON") +} + +func TestDispatch_TicketJSON_Good_DualRun(t *testing.T) { + ticket := DispatchTicket{ + ID: "test-dual", + RepoOwner: "host-uk", + RepoName: "core", + IssueNumber: 1, + ForgeURL: "https://forge.lthn.ai", + Model: "gemini-2.0-flash", + VerifyModel: "gemini-1.5-pro", + DualRun: true, + } + + data, err := json.Marshal(ticket) + require.NoError(t, err) + + var roundtrip DispatchTicket + err = json.Unmarshal(data, &roundtrip) + require.NoError(t, err) + assert.True(t, roundtrip.DualRun) + assert.Equal(t, "gemini-1.5-pro", roundtrip.VerifyModel) +} + +func TestDispatch_TicketJSON_Good_OmitsEmptyModelRunner(t *testing.T) { + ticket := DispatchTicket{ + ID: "test-1", + RepoOwner: "host-uk", + RepoName: "core", + IssueNumber: 1, + TargetBranch: "new", + ForgeURL: "https://forge.lthn.ai", + } + + data, err := json.MarshalIndent(ticket, "", " ") + require.NoError(t, err) + + var decoded map[string]any + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + + _, hasModel := decoded["model"] + _, hasRunner := decoded["runner"] + assert.False(t, hasModel, "model should be omitted when empty") + assert.False(t, hasRunner, "runner should be omitted when empty") +} + +func TestDispatch_TicketJSON_Good_ModelRunnerVariants(t *testing.T) { + tests := []struct { + name string + model string + runner string + }{ + {"claude-sonnet", "sonnet", "claude"}, + {"claude-opus", "opus", "claude"}, + {"codex-default", "", "codex"}, + {"gemini-default", "", "gemini"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ticket := DispatchTicket{ + ID: "test-" + tt.name, + RepoOwner: "host-uk", + RepoName: "core", + IssueNumber: 1, + TargetBranch: "new", + ForgeURL: "https://forge.lthn.ai", + Model: tt.model, + Runner: tt.runner, + } + + data, err := json.Marshal(ticket) + require.NoError(t, err) + + var roundtrip DispatchTicket + err = json.Unmarshal(data, &roundtrip) + require.NoError(t, err) + assert.Equal(t, tt.model, roundtrip.Model) + assert.Equal(t, tt.runner, roundtrip.Runner) + }) + } +} + +func TestDispatch_Execute_Good_PostsComment(t *testing.T) { + var commentPosted bool + var commentBody string + + srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + switch { + case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/host-uk/core/labels": + json.NewEncoder(w).Encode([]any{}) + return + + case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/host-uk/core/labels": + json.NewEncoder(w).Encode(map[string]any{"id": 1, "name": "in-progress", "color": "#1d76db"}) + return + + case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/host-uk/core/issues/5": + json.NewEncoder(w).Encode(map[string]any{"id": 5, "number": 5, "labels": []any{}, "title": "Test"}) + return + + case r.Method == http.MethodPatch && r.URL.Path == "/api/v1/repos/host-uk/core/issues/5": + json.NewEncoder(w).Encode(map[string]any{"id": 5, "number": 5}) + return + + case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/host-uk/core/issues/5/labels": + json.NewEncoder(w).Encode([]any{map[string]any{"id": 1, "name": "in-progress"}}) + return + + case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/host-uk/core/issues/5/comments": + commentPosted = true + var body map[string]string + _ = json.NewDecoder(r.Body).Decode(&body) + commentBody = body["body"] + 
json.NewEncoder(w).Encode(map[string]any{"id": 1, "body": body["body"]}) + return + } + + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]any{}) + }))) + defer srv.Close() + + client := newTestForgeClient(t, srv.URL) + + spinner := newTestSpinner(map[string]agentci.AgentConfig{ + "darbs-claude": {Host: "localhost", QueueDir: "/tmp/nonexistent-queue", Active: true}, + }) + h := NewDispatchHandler(client, srv.URL, "test-token", spinner) + + sig := &jobrunner.PipelineSignal{ + NeedsCoding: true, + Assignee: "darbs-claude", + RepoOwner: "host-uk", + RepoName: "core", + ChildNumber: 5, + EpicNumber: 3, + IssueTitle: "Test issue", + IssueBody: "Test body", + } + + result, err := h.Execute(context.Background(), sig) + require.NoError(t, err) + + assert.Equal(t, "dispatch", result.Action) + assert.Equal(t, "host-uk", result.RepoOwner) + assert.Equal(t, "core", result.RepoName) + assert.Equal(t, 3, result.EpicNumber) + assert.Equal(t, 5, result.ChildNumber) + + if result.Success { + assert.True(t, commentPosted) + assert.Contains(t, commentBody, "darbs-claude") + } +} diff --git a/jobrunner/handlers/enable_auto_merge.go b/jobrunner/handlers/enable_auto_merge.go new file mode 100644 index 0000000..dc919e7 --- /dev/null +++ b/jobrunner/handlers/enable_auto_merge.go @@ -0,0 +1,58 @@ +package handlers + +import ( + "context" + "fmt" + "time" + + "forge.lthn.ai/core/go-scm/forge" + "forge.lthn.ai/core/go-scm/jobrunner" +) + +// EnableAutoMergeHandler merges a PR that is ready using squash strategy. +type EnableAutoMergeHandler struct { + forge *forge.Client +} + +// NewEnableAutoMergeHandler creates a handler that merges ready PRs. +func NewEnableAutoMergeHandler(f *forge.Client) *EnableAutoMergeHandler { + return &EnableAutoMergeHandler{forge: f} +} + +// Name returns the handler identifier. +func (h *EnableAutoMergeHandler) Name() string { + return "enable_auto_merge" +} + +// Match returns true when the PR is open, not a draft, mergeable, checks +// are passing, and there are no unresolved review threads. +func (h *EnableAutoMergeHandler) Match(signal *jobrunner.PipelineSignal) bool { + return signal.PRState == "OPEN" && + !signal.IsDraft && + signal.Mergeable == "MERGEABLE" && + signal.CheckStatus == "SUCCESS" && + !signal.HasUnresolvedThreads() +} + +// Execute merges the pull request with squash strategy. 
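+// A merge failure is recorded in ActionResult.Error rather than returned as
+// an error, so the runner can journal the attempt.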
+func (h *EnableAutoMergeHandler) Execute(ctx context.Context, signal *jobrunner.PipelineSignal) (*jobrunner.ActionResult, error) { + start := time.Now() + + err := h.forge.MergePullRequest(signal.RepoOwner, signal.RepoName, int64(signal.PRNumber), "squash") + + result := &jobrunner.ActionResult{ + Action: "enable_auto_merge", + RepoOwner: signal.RepoOwner, + RepoName: signal.RepoName, + PRNumber: signal.PRNumber, + Success: err == nil, + Timestamp: time.Now(), + Duration: time.Since(start), + } + + if err != nil { + result.Error = fmt.Sprintf("merge failed: %v", err) + } + + return result, nil +} diff --git a/jobrunner/handlers/enable_auto_merge_test.go b/jobrunner/handlers/enable_auto_merge_test.go new file mode 100644 index 0000000..a85130e --- /dev/null +++ b/jobrunner/handlers/enable_auto_merge_test.go @@ -0,0 +1,105 @@ +package handlers + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "forge.lthn.ai/core/go-scm/jobrunner" +) + +func TestEnableAutoMerge_Match_Good(t *testing.T) { + h := NewEnableAutoMergeHandler(nil) + sig := &jobrunner.PipelineSignal{ + PRState: "OPEN", + IsDraft: false, + Mergeable: "MERGEABLE", + CheckStatus: "SUCCESS", + ThreadsTotal: 0, + ThreadsResolved: 0, + } + assert.True(t, h.Match(sig)) +} + +func TestEnableAutoMerge_Match_Bad_Draft(t *testing.T) { + h := NewEnableAutoMergeHandler(nil) + sig := &jobrunner.PipelineSignal{ + PRState: "OPEN", + IsDraft: true, + Mergeable: "MERGEABLE", + CheckStatus: "SUCCESS", + ThreadsTotal: 0, + ThreadsResolved: 0, + } + assert.False(t, h.Match(sig)) +} + +func TestEnableAutoMerge_Match_Bad_UnresolvedThreads(t *testing.T) { + h := NewEnableAutoMergeHandler(nil) + sig := &jobrunner.PipelineSignal{ + PRState: "OPEN", + IsDraft: false, + Mergeable: "MERGEABLE", + CheckStatus: "SUCCESS", + ThreadsTotal: 5, + ThreadsResolved: 3, + } + assert.False(t, h.Match(sig)) +} + +func TestEnableAutoMerge_Execute_Good(t *testing.T) { + var capturedPath string + var capturedMethod string + + srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedMethod = r.Method + capturedPath = r.URL.Path + w.WriteHeader(http.StatusOK) + }))) + defer srv.Close() + + client := newTestForgeClient(t, srv.URL) + + h := NewEnableAutoMergeHandler(client) + sig := &jobrunner.PipelineSignal{ + RepoOwner: "host-uk", + RepoName: "core-php", + PRNumber: 55, + } + + result, err := h.Execute(context.Background(), sig) + require.NoError(t, err) + + assert.True(t, result.Success) + assert.Equal(t, "enable_auto_merge", result.Action) + assert.Equal(t, http.MethodPost, capturedMethod) + assert.Equal(t, "/api/v1/repos/host-uk/core-php/pulls/55/merge", capturedPath) +} + +func TestEnableAutoMerge_Execute_Bad_MergeFailed(t *testing.T) { + srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusConflict) + _ = json.NewEncoder(w).Encode(map[string]string{"message": "merge conflict"}) + }))) + defer srv.Close() + + client := newTestForgeClient(t, srv.URL) + + h := NewEnableAutoMergeHandler(client) + sig := &jobrunner.PipelineSignal{ + RepoOwner: "host-uk", + RepoName: "core-php", + PRNumber: 55, + } + + result, err := h.Execute(context.Background(), sig) + require.NoError(t, err) + + assert.False(t, result.Success) + assert.Contains(t, result.Error, "merge failed") +} diff --git a/jobrunner/handlers/publish_draft.go 
b/jobrunner/handlers/publish_draft.go new file mode 100644 index 0000000..3604e9a --- /dev/null +++ b/jobrunner/handlers/publish_draft.go @@ -0,0 +1,55 @@ +package handlers + +import ( + "context" + "fmt" + "time" + + "forge.lthn.ai/core/go-scm/forge" + "forge.lthn.ai/core/go-scm/jobrunner" +) + +// PublishDraftHandler marks a draft PR as ready for review once its checks pass. +type PublishDraftHandler struct { + forge *forge.Client +} + +// NewPublishDraftHandler creates a handler that publishes draft PRs. +func NewPublishDraftHandler(f *forge.Client) *PublishDraftHandler { + return &PublishDraftHandler{forge: f} +} + +// Name returns the handler identifier. +func (h *PublishDraftHandler) Name() string { + return "publish_draft" +} + +// Match returns true when the PR is a draft, open, and all checks have passed. +func (h *PublishDraftHandler) Match(signal *jobrunner.PipelineSignal) bool { + return signal.IsDraft && + signal.PRState == "OPEN" && + signal.CheckStatus == "SUCCESS" +} + +// Execute marks the PR as no longer a draft. +func (h *PublishDraftHandler) Execute(ctx context.Context, signal *jobrunner.PipelineSignal) (*jobrunner.ActionResult, error) { + start := time.Now() + + err := h.forge.SetPRDraft(signal.RepoOwner, signal.RepoName, int64(signal.PRNumber), false) + + result := &jobrunner.ActionResult{ + Action: "publish_draft", + RepoOwner: signal.RepoOwner, + RepoName: signal.RepoName, + PRNumber: signal.PRNumber, + Success: err == nil, + Timestamp: time.Now(), + Duration: time.Since(start), + } + + if err != nil { + result.Error = fmt.Sprintf("publish draft failed: %v", err) + } + + return result, nil +} diff --git a/jobrunner/handlers/publish_draft_test.go b/jobrunner/handlers/publish_draft_test.go new file mode 100644 index 0000000..dd76e81 --- /dev/null +++ b/jobrunner/handlers/publish_draft_test.go @@ -0,0 +1,84 @@ +package handlers + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "forge.lthn.ai/core/go-scm/jobrunner" +) + +func TestPublishDraft_Match_Good(t *testing.T) { + h := NewPublishDraftHandler(nil) + sig := &jobrunner.PipelineSignal{ + IsDraft: true, + PRState: "OPEN", + CheckStatus: "SUCCESS", + } + assert.True(t, h.Match(sig)) +} + +func TestPublishDraft_Match_Bad_NotDraft(t *testing.T) { + h := NewPublishDraftHandler(nil) + sig := &jobrunner.PipelineSignal{ + IsDraft: false, + PRState: "OPEN", + CheckStatus: "SUCCESS", + } + assert.False(t, h.Match(sig)) +} + +func TestPublishDraft_Match_Bad_ChecksFailing(t *testing.T) { + h := NewPublishDraftHandler(nil) + sig := &jobrunner.PipelineSignal{ + IsDraft: true, + PRState: "OPEN", + CheckStatus: "FAILURE", + } + assert.False(t, h.Match(sig)) +} + +func TestPublishDraft_Execute_Good(t *testing.T) { + var capturedMethod string + var capturedPath string + var capturedBody string + + srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedMethod = r.Method + capturedPath = r.URL.Path + b, _ := io.ReadAll(r.Body) + capturedBody = string(b) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{}`)) + }))) + defer srv.Close() + + client := newTestForgeClient(t, srv.URL) + + h := NewPublishDraftHandler(client) + sig := &jobrunner.PipelineSignal{ + RepoOwner: "host-uk", + RepoName: "core-php", + PRNumber: 42, + IsDraft: true, + PRState: "OPEN", + } + + result, err := h.Execute(context.Background(), sig) + require.NoError(t, err) + + assert.Equal(t, 
http.MethodPatch, capturedMethod) + assert.Equal(t, "/api/v1/repos/host-uk/core-php/pulls/42", capturedPath) + assert.Contains(t, capturedBody, `"draft":false`) + + assert.True(t, result.Success) + assert.Equal(t, "publish_draft", result.Action) + assert.Equal(t, "host-uk", result.RepoOwner) + assert.Equal(t, "core-php", result.RepoName) + assert.Equal(t, 42, result.PRNumber) +} diff --git a/jobrunner/handlers/resolve_threads.go b/jobrunner/handlers/resolve_threads.go new file mode 100644 index 0000000..acb8477 --- /dev/null +++ b/jobrunner/handlers/resolve_threads.go @@ -0,0 +1,79 @@ +package handlers + +import ( + "context" + "fmt" + "time" + + forgejosdk "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2" + + "forge.lthn.ai/core/go-scm/forge" + "forge.lthn.ai/core/go-scm/jobrunner" +) + +// DismissReviewsHandler dismisses stale "request changes" reviews on a PR. +// This replaces the GitHub-only ResolveThreadsHandler because Forgejo does +// not have a thread resolution API. +type DismissReviewsHandler struct { + forge *forge.Client +} + +// NewDismissReviewsHandler creates a handler that dismisses stale reviews. +func NewDismissReviewsHandler(f *forge.Client) *DismissReviewsHandler { + return &DismissReviewsHandler{forge: f} +} + +// Name returns the handler identifier. +func (h *DismissReviewsHandler) Name() string { + return "dismiss_reviews" +} + +// Match returns true when the PR is open and has unresolved review threads. +func (h *DismissReviewsHandler) Match(signal *jobrunner.PipelineSignal) bool { + return signal.PRState == "OPEN" && signal.HasUnresolvedThreads() +} + +// Execute dismisses stale "request changes" reviews on the PR. +func (h *DismissReviewsHandler) Execute(ctx context.Context, signal *jobrunner.PipelineSignal) (*jobrunner.ActionResult, error) { + start := time.Now() + + reviews, err := h.forge.ListPRReviews(signal.RepoOwner, signal.RepoName, int64(signal.PRNumber)) + if err != nil { + return nil, fmt.Errorf("dismiss_reviews: list reviews: %w", err) + } + + var dismissErrors []string + dismissed := 0 + for _, review := range reviews { + if review.State != forgejosdk.ReviewStateRequestChanges || review.Dismissed || !review.Stale { + continue + } + + if err := h.forge.DismissReview( + signal.RepoOwner, signal.RepoName, + int64(signal.PRNumber), review.ID, + "Automatically dismissed: review is stale after new commits", + ); err != nil { + dismissErrors = append(dismissErrors, err.Error()) + } else { + dismissed++ + } + } + + result := &jobrunner.ActionResult{ + Action: "dismiss_reviews", + RepoOwner: signal.RepoOwner, + RepoName: signal.RepoName, + PRNumber: signal.PRNumber, + Success: len(dismissErrors) == 0, + Timestamp: time.Now(), + Duration: time.Since(start), + } + + if len(dismissErrors) > 0 { + result.Error = fmt.Sprintf("failed to dismiss %d review(s): %s", + len(dismissErrors), dismissErrors[0]) + } + + return result, nil +} diff --git a/jobrunner/handlers/resolve_threads_test.go b/jobrunner/handlers/resolve_threads_test.go new file mode 100644 index 0000000..4b09208 --- /dev/null +++ b/jobrunner/handlers/resolve_threads_test.go @@ -0,0 +1,91 @@ +package handlers + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "forge.lthn.ai/core/go-scm/jobrunner" +) + +func TestDismissReviews_Match_Good(t *testing.T) { + h := NewDismissReviewsHandler(nil) + sig := &jobrunner.PipelineSignal{ + PRState: "OPEN", + ThreadsTotal: 4, + ThreadsResolved: 2, + } 
+ assert.True(t, h.Match(sig)) +} + +func TestDismissReviews_Match_Bad_AllResolved(t *testing.T) { + h := NewDismissReviewsHandler(nil) + sig := &jobrunner.PipelineSignal{ + PRState: "OPEN", + ThreadsTotal: 3, + ThreadsResolved: 3, + } + assert.False(t, h.Match(sig)) +} + +func TestDismissReviews_Execute_Good(t *testing.T) { + callCount := 0 + + srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.Header().Set("Content-Type", "application/json") + + // ListPullReviews (GET) + if r.Method == http.MethodGet { + reviews := []map[string]any{ + { + "id": 1, "state": "REQUEST_CHANGES", "dismissed": false, "stale": true, + "body": "fix this", "commit_id": "abc123", + }, + { + "id": 2, "state": "APPROVED", "dismissed": false, "stale": false, + "body": "looks good", "commit_id": "abc123", + }, + { + "id": 3, "state": "REQUEST_CHANGES", "dismissed": false, "stale": true, + "body": "needs work", "commit_id": "abc123", + }, + } + _ = json.NewEncoder(w).Encode(reviews) + return + } + + // DismissPullReview (POST to dismissals endpoint) + w.WriteHeader(http.StatusOK) + }))) + defer srv.Close() + + client := newTestForgeClient(t, srv.URL) + + h := NewDismissReviewsHandler(client) + sig := &jobrunner.PipelineSignal{ + RepoOwner: "host-uk", + RepoName: "core-admin", + PRNumber: 33, + PRState: "OPEN", + ThreadsTotal: 3, + ThreadsResolved: 1, + } + + result, err := h.Execute(context.Background(), sig) + require.NoError(t, err) + + assert.True(t, result.Success) + assert.Equal(t, "dismiss_reviews", result.Action) + assert.Equal(t, "host-uk", result.RepoOwner) + assert.Equal(t, "core-admin", result.RepoName) + assert.Equal(t, 33, result.PRNumber) + + // 1 list + 2 dismiss (reviews #1 and #3 are stale REQUEST_CHANGES) + assert.Equal(t, 3, callCount) +} diff --git a/jobrunner/handlers/send_fix_command.go b/jobrunner/handlers/send_fix_command.go new file mode 100644 index 0000000..bfb7202 --- /dev/null +++ b/jobrunner/handlers/send_fix_command.go @@ -0,0 +1,74 @@ +package handlers + +import ( + "context" + "fmt" + "time" + + "forge.lthn.ai/core/go-scm/forge" + "forge.lthn.ai/core/go-scm/jobrunner" +) + +// SendFixCommandHandler posts a comment on a PR asking for conflict or +// review fixes. +type SendFixCommandHandler struct { + forge *forge.Client +} + +// NewSendFixCommandHandler creates a handler that posts fix commands. +func NewSendFixCommandHandler(f *forge.Client) *SendFixCommandHandler { + return &SendFixCommandHandler{forge: f} +} + +// Name returns the handler identifier. +func (h *SendFixCommandHandler) Name() string { + return "send_fix_command" +} + +// Match returns true when the PR is open and either has merge conflicts or +// has unresolved threads with failing checks. +func (h *SendFixCommandHandler) Match(signal *jobrunner.PipelineSignal) bool { + if signal.PRState != "OPEN" { + return false + } + if signal.Mergeable == "CONFLICTING" { + return true + } + if signal.HasUnresolvedThreads() && signal.CheckStatus == "FAILURE" { + return true + } + return false +} + +// Execute posts a comment on the PR asking for a fix. +func (h *SendFixCommandHandler) Execute(ctx context.Context, signal *jobrunner.PipelineSignal) (*jobrunner.ActionResult, error) { + start := time.Now() + + var message string + if signal.Mergeable == "CONFLICTING" { + message = "Can you fix the merge conflict?" + } else { + message = "Can you fix the code reviews?" 
+ } + + err := h.forge.CreateIssueComment( + signal.RepoOwner, signal.RepoName, + int64(signal.PRNumber), message, + ) + + result := &jobrunner.ActionResult{ + Action: "send_fix_command", + RepoOwner: signal.RepoOwner, + RepoName: signal.RepoName, + PRNumber: signal.PRNumber, + Success: err == nil, + Timestamp: time.Now(), + Duration: time.Since(start), + } + + if err != nil { + result.Error = fmt.Sprintf("post comment failed: %v", err) + } + + return result, nil +} diff --git a/jobrunner/handlers/send_fix_command_test.go b/jobrunner/handlers/send_fix_command_test.go new file mode 100644 index 0000000..871c739 --- /dev/null +++ b/jobrunner/handlers/send_fix_command_test.go @@ -0,0 +1,87 @@ +package handlers + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "forge.lthn.ai/core/go-scm/jobrunner" +) + +func TestSendFixCommand_Match_Good_Conflicting(t *testing.T) { + h := NewSendFixCommandHandler(nil) + sig := &jobrunner.PipelineSignal{ + PRState: "OPEN", + Mergeable: "CONFLICTING", + } + assert.True(t, h.Match(sig)) +} + +func TestSendFixCommand_Match_Good_UnresolvedThreads(t *testing.T) { + h := NewSendFixCommandHandler(nil) + sig := &jobrunner.PipelineSignal{ + PRState: "OPEN", + Mergeable: "MERGEABLE", + CheckStatus: "FAILURE", + ThreadsTotal: 3, + ThreadsResolved: 1, + } + assert.True(t, h.Match(sig)) +} + +func TestSendFixCommand_Match_Bad_Clean(t *testing.T) { + h := NewSendFixCommandHandler(nil) + sig := &jobrunner.PipelineSignal{ + PRState: "OPEN", + Mergeable: "MERGEABLE", + CheckStatus: "SUCCESS", + ThreadsTotal: 2, + ThreadsResolved: 2, + } + assert.False(t, h.Match(sig)) +} + +func TestSendFixCommand_Execute_Good_Conflict(t *testing.T) { + var capturedMethod string + var capturedPath string + var capturedBody string + + srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedMethod = r.Method + capturedPath = r.URL.Path + b, _ := io.ReadAll(r.Body) + capturedBody = string(b) + w.WriteHeader(http.StatusCreated) + _, _ = w.Write([]byte(`{"id":1}`)) + }))) + defer srv.Close() + + client := newTestForgeClient(t, srv.URL) + + h := NewSendFixCommandHandler(client) + sig := &jobrunner.PipelineSignal{ + RepoOwner: "host-uk", + RepoName: "core-tenant", + PRNumber: 17, + PRState: "OPEN", + Mergeable: "CONFLICTING", + } + + result, err := h.Execute(context.Background(), sig) + require.NoError(t, err) + + assert.Equal(t, http.MethodPost, capturedMethod) + assert.Equal(t, "/api/v1/repos/host-uk/core-tenant/issues/17/comments", capturedPath) + assert.Contains(t, capturedBody, "fix the merge conflict") + + assert.True(t, result.Success) + assert.Equal(t, "send_fix_command", result.Action) + assert.Equal(t, "host-uk", result.RepoOwner) + assert.Equal(t, "core-tenant", result.RepoName) + assert.Equal(t, 17, result.PRNumber) +} diff --git a/jobrunner/handlers/testhelper_test.go b/jobrunner/handlers/testhelper_test.go new file mode 100644 index 0000000..277591c --- /dev/null +++ b/jobrunner/handlers/testhelper_test.go @@ -0,0 +1,35 @@ +package handlers + +import ( + "net/http" + "strings" + "testing" + + "github.com/stretchr/testify/require" + + "forge.lthn.ai/core/go-scm/forge" +) + +// forgejoVersionResponse is the JSON response for /api/v1/version. 
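+// The exact version value is not significant for these tests; it only has to parse.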
+const forgejoVersionResponse = `{"version":"9.0.0"}` + +// withVersion wraps an HTTP handler to also serve the Forgejo version endpoint +// that the SDK calls during NewClient initialization. +func withVersion(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.HasSuffix(r.URL.Path, "/version") { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(forgejoVersionResponse)) + return + } + next.ServeHTTP(w, r) + }) +} + +// newTestForgeClient creates a forge.Client pointing at the given test server URL. +func newTestForgeClient(t *testing.T, url string) *forge.Client { + t.Helper() + client, err := forge.New(url, "test-token") + require.NoError(t, err) + return client +} diff --git a/jobrunner/handlers/tick_parent.go b/jobrunner/handlers/tick_parent.go new file mode 100644 index 0000000..fa7db10 --- /dev/null +++ b/jobrunner/handlers/tick_parent.go @@ -0,0 +1,100 @@ +package handlers + +import ( + "context" + "fmt" + "strings" + "time" + + forgejosdk "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2" + + "forge.lthn.ai/core/go-scm/forge" + "forge.lthn.ai/core/go-scm/jobrunner" +) + +// TickParentHandler ticks a child checkbox in the parent epic issue body +// after the child's PR has been merged. +type TickParentHandler struct { + forge *forge.Client +} + +// NewTickParentHandler creates a handler that ticks parent epic checkboxes. +func NewTickParentHandler(f *forge.Client) *TickParentHandler { + return &TickParentHandler{forge: f} +} + +// Name returns the handler identifier. +func (h *TickParentHandler) Name() string { + return "tick_parent" +} + +// Match returns true when the child PR has been merged. +func (h *TickParentHandler) Match(signal *jobrunner.PipelineSignal) bool { + return signal.PRState == "MERGED" +} + +// Execute fetches the epic body, replaces the unchecked checkbox for the +// child issue with a checked one, updates the epic, and closes the child issue. +func (h *TickParentHandler) Execute(ctx context.Context, signal *jobrunner.PipelineSignal) (*jobrunner.ActionResult, error) { + start := time.Now() + + // Fetch the epic issue body. + epic, err := h.forge.GetIssue(signal.RepoOwner, signal.RepoName, int64(signal.EpicNumber)) + if err != nil { + return nil, fmt.Errorf("tick_parent: fetch epic: %w", err) + } + + oldBody := epic.Body + unchecked := fmt.Sprintf("- [ ] #%d", signal.ChildNumber) + checked := fmt.Sprintf("- [x] #%d", signal.ChildNumber) + + if !strings.Contains(oldBody, unchecked) { + // Already ticked or not found -- nothing to do. + return &jobrunner.ActionResult{ + Action: "tick_parent", + RepoOwner: signal.RepoOwner, + RepoName: signal.RepoName, + PRNumber: signal.PRNumber, + Success: true, + Timestamp: time.Now(), + Duration: time.Since(start), + }, nil + } + + newBody := strings.Replace(oldBody, unchecked, checked, 1) + + // Update the epic body. + _, err = h.forge.EditIssue(signal.RepoOwner, signal.RepoName, int64(signal.EpicNumber), forgejosdk.EditIssueOption{ + Body: &newBody, + }) + if err != nil { + return &jobrunner.ActionResult{ + Action: "tick_parent", + RepoOwner: signal.RepoOwner, + RepoName: signal.RepoName, + PRNumber: signal.PRNumber, + Error: fmt.Sprintf("edit epic failed: %v", err), + Timestamp: time.Now(), + Duration: time.Since(start), + }, nil + } + + // Close the child issue. 
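+	// If closing fails, the epic checkbox stays ticked and the error is
+	// surfaced via result.Error below.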
+ err = h.forge.CloseIssue(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber)) + + result := &jobrunner.ActionResult{ + Action: "tick_parent", + RepoOwner: signal.RepoOwner, + RepoName: signal.RepoName, + PRNumber: signal.PRNumber, + Success: err == nil, + Timestamp: time.Now(), + Duration: time.Since(start), + } + + if err != nil { + result.Error = fmt.Sprintf("close child issue failed: %v", err) + } + + return result, nil +} diff --git a/jobrunner/handlers/tick_parent_test.go b/jobrunner/handlers/tick_parent_test.go new file mode 100644 index 0000000..4a8fd78 --- /dev/null +++ b/jobrunner/handlers/tick_parent_test.go @@ -0,0 +1,98 @@ +package handlers + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "forge.lthn.ai/core/go-scm/jobrunner" +) + +func TestTickParent_Match_Good(t *testing.T) { + h := NewTickParentHandler(nil) + sig := &jobrunner.PipelineSignal{ + PRState: "MERGED", + } + assert.True(t, h.Match(sig)) +} + +func TestTickParent_Match_Bad_Open(t *testing.T) { + h := NewTickParentHandler(nil) + sig := &jobrunner.PipelineSignal{ + PRState: "OPEN", + } + assert.False(t, h.Match(sig)) +} + +func TestTickParent_Execute_Good(t *testing.T) { + epicBody := "## Tasks\n- [x] #1\n- [ ] #7\n- [ ] #8\n" + var editBody string + var closeCalled bool + + srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + path := r.URL.Path + method := r.Method + w.Header().Set("Content-Type", "application/json") + + switch { + // GET issue (fetch epic) + case method == http.MethodGet && strings.Contains(path, "/issues/42"): + _ = json.NewEncoder(w).Encode(map[string]any{ + "number": 42, + "body": epicBody, + "title": "Epic", + }) + + // PATCH issue (edit epic body) + case method == http.MethodPatch && strings.Contains(path, "/issues/42"): + b, _ := io.ReadAll(r.Body) + editBody = string(b) + _ = json.NewEncoder(w).Encode(map[string]any{ + "number": 42, + "body": editBody, + "title": "Epic", + }) + + // PATCH issue (close child — state: closed) + case method == http.MethodPatch && strings.Contains(path, "/issues/7"): + closeCalled = true + _ = json.NewEncoder(w).Encode(map[string]any{ + "number": 7, + "state": "closed", + }) + + default: + w.WriteHeader(http.StatusNotFound) + } + }))) + defer srv.Close() + + client := newTestForgeClient(t, srv.URL) + + h := NewTickParentHandler(client) + sig := &jobrunner.PipelineSignal{ + RepoOwner: "host-uk", + RepoName: "core-php", + EpicNumber: 42, + ChildNumber: 7, + PRNumber: 99, + PRState: "MERGED", + } + + result, err := h.Execute(context.Background(), sig) + require.NoError(t, err) + + assert.True(t, result.Success) + assert.Equal(t, "tick_parent", result.Action) + + // Verify the edit body contains the checked checkbox. + assert.Contains(t, editBody, "- [x] #7") + assert.True(t, closeCalled, "expected child issue to be closed") +} diff --git a/jobrunner/journal.go b/jobrunner/journal.go new file mode 100644 index 0000000..c09ffcf --- /dev/null +++ b/jobrunner/journal.go @@ -0,0 +1,170 @@ +package jobrunner + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "regexp" + "strings" + "sync" +) + +// validPathComponent matches safe repo owner/name characters (alphanumeric, hyphen, underscore, dot). +var validPathComponent = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9._-]*$`) + +// JournalEntry is a single line in the JSONL audit log. 
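+// A typical line (illustrative values):
+//
+//	{"ts":"2026-02-05T14:30:00Z","epic":10,"child":3,"pr":55,"repo":"host-uk/core-tenant","action":"merge",...}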
+type JournalEntry struct {
+	Timestamp string         `json:"ts"`
+	Epic      int            `json:"epic"`
+	Child     int            `json:"child"`
+	PR        int            `json:"pr"`
+	Repo      string         `json:"repo"`
+	Action    string         `json:"action"`
+	Signals   SignalSnapshot `json:"signals"`
+	Result    ResultSnapshot `json:"result"`
+	Cycle     int            `json:"cycle"`
+}
+
+// SignalSnapshot captures the structural state of a PR at the time of action.
+type SignalSnapshot struct {
+	PRState         string `json:"pr_state"`
+	IsDraft         bool   `json:"is_draft"`
+	CheckStatus     string `json:"check_status"`
+	Mergeable       string `json:"mergeable"`
+	ThreadsTotal    int    `json:"threads_total"`
+	ThreadsResolved int    `json:"threads_resolved"`
+}
+
+// ResultSnapshot captures the outcome of an action.
+type ResultSnapshot struct {
+	Success    bool   `json:"success"`
+	Error      string `json:"error,omitempty"`
+	DurationMs int64  `json:"duration_ms"`
+}
+
+// Journal writes ActionResult entries to date-partitioned JSONL files.
+type Journal struct {
+	baseDir string
+	mu      sync.Mutex
+}
+
+// NewJournal creates a new Journal rooted at baseDir.
+func NewJournal(baseDir string) (*Journal, error) {
+	if baseDir == "" {
+		return nil, fmt.Errorf("journal base directory is required")
+	}
+	return &Journal{baseDir: baseDir}, nil
+}
+
+// sanitizePathComponent validates a single path component (owner or repo name)
+// to prevent path traversal attacks. It rejects "..", empty strings, paths
+// containing separators, and any value outside the safe character set.
+func sanitizePathComponent(name string) (string, error) {
+	// Reject empty or whitespace-only values.
+	if name == "" || strings.TrimSpace(name) == "" {
+		return "", fmt.Errorf("invalid path component: %q", name)
+	}
+
+	// Reject inputs containing path separators (directory traversal attempt).
+	if strings.ContainsAny(name, `/\`) {
+		return "", fmt.Errorf("path component contains directory separator: %q", name)
+	}
+
+	// Use filepath.Clean to normalize (e.g., collapse redundant dots).
+	clean := filepath.Clean(name)
+
+	// Reject traversal components.
+	if clean == "." || clean == ".." {
+		return "", fmt.Errorf("invalid path component: %q", name)
+	}
+
+	// Validate against the safe character set.
+	if !validPathComponent.MatchString(clean) {
+		return "", fmt.Errorf("path component contains invalid characters: %q", name)
+	}
+
+	return clean, nil
+}
+
+// Append writes a journal entry for the given signal and result.
+func (j *Journal) Append(signal *PipelineSignal, result *ActionResult) error {
+	if signal == nil {
+		return fmt.Errorf("signal is required")
+	}
+	if result == nil {
+		return fmt.Errorf("result is required")
+	}
+
+	entry := JournalEntry{
+		Timestamp: result.Timestamp.UTC().Format("2006-01-02T15:04:05Z"),
+		Epic:      signal.EpicNumber,
+		Child:     signal.ChildNumber,
+		PR:        signal.PRNumber,
+		Repo:      signal.RepoFullName(),
+		Action:    result.Action,
+		Signals: SignalSnapshot{
+			PRState:         signal.PRState,
+			IsDraft:         signal.IsDraft,
+			CheckStatus:     signal.CheckStatus,
+			Mergeable:       signal.Mergeable,
+			ThreadsTotal:    signal.ThreadsTotal,
+			ThreadsResolved: signal.ThreadsResolved,
+		},
+		Result: ResultSnapshot{
+			Success:    result.Success,
+			Error:      result.Error,
+			DurationMs: result.Duration.Milliseconds(),
+		},
+		Cycle: result.Cycle,
+	}
+
+	data, err := json.Marshal(entry)
+	if err != nil {
+		return fmt.Errorf("marshal journal entry: %w", err)
+	}
+	data = append(data, '\n')
+
+	// Sanitize path components to prevent path traversal (see issue #46).
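+	// A repo name like "../../etc/cron.d" (as exercised in the tests) must be
+	// rejected here rather than allowed to escape baseDir.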
+ owner, err := sanitizePathComponent(signal.RepoOwner) + if err != nil { + return fmt.Errorf("invalid repo owner: %w", err) + } + repo, err := sanitizePathComponent(signal.RepoName) + if err != nil { + return fmt.Errorf("invalid repo name: %w", err) + } + + date := result.Timestamp.UTC().Format("2006-01-02") + dir := filepath.Join(j.baseDir, owner, repo) + + // Resolve to absolute path and verify it stays within baseDir. + absBase, err := filepath.Abs(j.baseDir) + if err != nil { + return fmt.Errorf("resolve base directory: %w", err) + } + absDir, err := filepath.Abs(dir) + if err != nil { + return fmt.Errorf("resolve journal directory: %w", err) + } + if !strings.HasPrefix(absDir, absBase+string(filepath.Separator)) { + return fmt.Errorf("journal path %q escapes base directory %q", absDir, absBase) + } + + j.mu.Lock() + defer j.mu.Unlock() + + if err := os.MkdirAll(dir, 0o755); err != nil { + return fmt.Errorf("create journal directory: %w", err) + } + + path := filepath.Join(dir, date+".jsonl") + f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) + if err != nil { + return fmt.Errorf("open journal file: %w", err) + } + defer func() { _ = f.Close() }() + + _, err = f.Write(data) + return err +} diff --git a/jobrunner/journal_test.go b/jobrunner/journal_test.go new file mode 100644 index 0000000..a17a88b --- /dev/null +++ b/jobrunner/journal_test.go @@ -0,0 +1,263 @@ +package jobrunner + +import ( + "bufio" + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestJournal_Append_Good(t *testing.T) { + dir := t.TempDir() + + j, err := NewJournal(dir) + require.NoError(t, err) + + ts := time.Date(2026, 2, 5, 14, 30, 0, 0, time.UTC) + + signal := &PipelineSignal{ + EpicNumber: 10, + ChildNumber: 3, + PRNumber: 55, + RepoOwner: "host-uk", + RepoName: "core-tenant", + PRState: "OPEN", + IsDraft: false, + Mergeable: "MERGEABLE", + CheckStatus: "SUCCESS", + ThreadsTotal: 2, + ThreadsResolved: 1, + LastCommitSHA: "abc123", + LastCommitAt: ts, + LastReviewAt: ts, + } + + result := &ActionResult{ + Action: "merge", + RepoOwner: "host-uk", + RepoName: "core-tenant", + EpicNumber: 10, + ChildNumber: 3, + PRNumber: 55, + Success: true, + Timestamp: ts, + Duration: 1200 * time.Millisecond, + Cycle: 1, + } + + err = j.Append(signal, result) + require.NoError(t, err) + + // Read the file back. + expectedPath := filepath.Join(dir, "host-uk", "core-tenant", "2026-02-05.jsonl") + f, err := os.Open(expectedPath) + require.NoError(t, err) + defer func() { _ = f.Close() }() + + scanner := bufio.NewScanner(f) + require.True(t, scanner.Scan(), "expected at least one line in JSONL file") + + var entry JournalEntry + err = json.Unmarshal(scanner.Bytes(), &entry) + require.NoError(t, err) + + assert.Equal(t, "2026-02-05T14:30:00Z", entry.Timestamp) + assert.Equal(t, 10, entry.Epic) + assert.Equal(t, 3, entry.Child) + assert.Equal(t, 55, entry.PR) + assert.Equal(t, "host-uk/core-tenant", entry.Repo) + assert.Equal(t, "merge", entry.Action) + assert.Equal(t, 1, entry.Cycle) + + // Verify signal snapshot. + assert.Equal(t, "OPEN", entry.Signals.PRState) + assert.Equal(t, false, entry.Signals.IsDraft) + assert.Equal(t, "SUCCESS", entry.Signals.CheckStatus) + assert.Equal(t, "MERGEABLE", entry.Signals.Mergeable) + assert.Equal(t, 2, entry.Signals.ThreadsTotal) + assert.Equal(t, 1, entry.Signals.ThreadsResolved) + + // Verify result snapshot. 
+ assert.Equal(t, true, entry.Result.Success) + assert.Equal(t, "", entry.Result.Error) + assert.Equal(t, int64(1200), entry.Result.DurationMs) + + // Append a second entry and verify two lines exist. + result2 := &ActionResult{ + Action: "comment", + RepoOwner: "host-uk", + RepoName: "core-tenant", + Success: false, + Error: "rate limited", + Timestamp: ts, + Duration: 50 * time.Millisecond, + Cycle: 2, + } + err = j.Append(signal, result2) + require.NoError(t, err) + + data, err := os.ReadFile(expectedPath) + require.NoError(t, err) + + lines := 0 + sc := bufio.NewScanner(strings.NewReader(string(data))) + for sc.Scan() { + lines++ + } + assert.Equal(t, 2, lines, "expected two JSONL lines after two appends") +} + +func TestJournal_Append_Bad_PathTraversal(t *testing.T) { + dir := t.TempDir() + + j, err := NewJournal(dir) + require.NoError(t, err) + + ts := time.Now() + + tests := []struct { + name string + repoOwner string + repoName string + wantErr string + }{ + { + name: "dotdot owner", + repoOwner: "..", + repoName: "core", + wantErr: "invalid repo owner", + }, + { + name: "dotdot repo", + repoOwner: "host-uk", + repoName: "../../etc/cron.d", + wantErr: "invalid repo name", + }, + { + name: "slash in owner", + repoOwner: "../etc", + repoName: "core", + wantErr: "invalid repo owner", + }, + { + name: "absolute path in repo", + repoOwner: "host-uk", + repoName: "/etc/passwd", + wantErr: "invalid repo name", + }, + { + name: "empty owner", + repoOwner: "", + repoName: "core", + wantErr: "invalid repo owner", + }, + { + name: "empty repo", + repoOwner: "host-uk", + repoName: "", + wantErr: "invalid repo name", + }, + { + name: "dot only owner", + repoOwner: ".", + repoName: "core", + wantErr: "invalid repo owner", + }, + { + name: "spaces only owner", + repoOwner: " ", + repoName: "core", + wantErr: "invalid repo owner", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + signal := &PipelineSignal{ + RepoOwner: tc.repoOwner, + RepoName: tc.repoName, + } + result := &ActionResult{ + Action: "merge", + Timestamp: ts, + } + + err := j.Append(signal, result) + require.Error(t, err) + assert.Contains(t, err.Error(), tc.wantErr) + }) + } +} + +func TestJournal_Append_Good_ValidNames(t *testing.T) { + dir := t.TempDir() + + j, err := NewJournal(dir) + require.NoError(t, err) + + ts := time.Date(2026, 2, 5, 14, 30, 0, 0, time.UTC) + + // Verify valid names with dots, hyphens, underscores all work. 
+ validNames := []struct { + owner string + repo string + }{ + {"host-uk", "core"}, + {"my_org", "my_repo"}, + {"org.name", "repo.v2"}, + {"a", "b"}, + {"Org-123", "Repo_456.go"}, + } + + for _, vn := range validNames { + signal := &PipelineSignal{ + RepoOwner: vn.owner, + RepoName: vn.repo, + } + result := &ActionResult{ + Action: "test", + Timestamp: ts, + } + + err := j.Append(signal, result) + assert.NoError(t, err, "expected valid name pair %s/%s to succeed", vn.owner, vn.repo) + } +} + +func TestJournal_Append_Bad_NilSignal(t *testing.T) { + dir := t.TempDir() + + j, err := NewJournal(dir) + require.NoError(t, err) + + result := &ActionResult{ + Action: "merge", + Timestamp: time.Now(), + } + + err = j.Append(nil, result) + require.Error(t, err) + assert.Contains(t, err.Error(), "signal is required") +} + +func TestJournal_Append_Bad_NilResult(t *testing.T) { + dir := t.TempDir() + + j, err := NewJournal(dir) + require.NoError(t, err) + + signal := &PipelineSignal{ + RepoOwner: "host-uk", + RepoName: "core-php", + } + + err = j.Append(signal, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "result is required") +} diff --git a/jobrunner/poller.go b/jobrunner/poller.go new file mode 100644 index 0000000..be6b213 --- /dev/null +++ b/jobrunner/poller.go @@ -0,0 +1,195 @@ +package jobrunner + +import ( + "context" + "sync" + "time" + + "forge.lthn.ai/core/go/pkg/log" +) + +// PollerConfig configures a Poller. +type PollerConfig struct { + Sources []JobSource + Handlers []JobHandler + Journal *Journal + PollInterval time.Duration + DryRun bool +} + +// Poller discovers signals from sources and dispatches them to handlers. +type Poller struct { + mu sync.RWMutex + sources []JobSource + handlers []JobHandler + journal *Journal + interval time.Duration + dryRun bool + cycle int +} + +// NewPoller creates a Poller from the given config. +func NewPoller(cfg PollerConfig) *Poller { + interval := cfg.PollInterval + if interval <= 0 { + interval = 60 * time.Second + } + + return &Poller{ + sources: cfg.Sources, + handlers: cfg.Handlers, + journal: cfg.Journal, + interval: interval, + dryRun: cfg.DryRun, + } +} + +// Cycle returns the number of completed poll-dispatch cycles. +func (p *Poller) Cycle() int { + p.mu.RLock() + defer p.mu.RUnlock() + return p.cycle +} + +// DryRun returns whether dry-run mode is enabled. +func (p *Poller) DryRun() bool { + p.mu.RLock() + defer p.mu.RUnlock() + return p.dryRun +} + +// SetDryRun enables or disables dry-run mode. +func (p *Poller) SetDryRun(v bool) { + p.mu.Lock() + p.dryRun = v + p.mu.Unlock() +} + +// AddSource appends a source to the poller. +func (p *Poller) AddSource(s JobSource) { + p.mu.Lock() + p.sources = append(p.sources, s) + p.mu.Unlock() +} + +// AddHandler appends a handler to the poller. +func (p *Poller) AddHandler(h JobHandler) { + p.mu.Lock() + p.handlers = append(p.handlers, h) + p.mu.Unlock() +} + +// Run starts a blocking poll-dispatch loop. It runs one cycle immediately, +// then repeats on each tick of the configured interval until the context +// is cancelled. 
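+//
+// A typical wiring looks like this (illustrative; src and h are
+// caller-provided JobSource and JobHandler implementations):
+//
+//	p := NewPoller(PollerConfig{
+//		Sources:      []JobSource{src},
+//		Handlers:     []JobHandler{h},
+//		PollInterval: time.Minute,
+//	})
+//	if err := p.Run(ctx); err != nil && !errors.Is(err, context.Canceled) {
+//		log.Error("poller stopped", "err", err)
+//	}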
+func (p *Poller) Run(ctx context.Context) error { + if err := p.RunOnce(ctx); err != nil { + return err + } + + ticker := time.NewTicker(p.interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + if err := p.RunOnce(ctx); err != nil { + return err + } + } + } +} + +// RunOnce performs a single poll-dispatch cycle: iterate sources, poll each, +// find the first matching handler for each signal, and execute it. +func (p *Poller) RunOnce(ctx context.Context) error { + p.mu.Lock() + p.cycle++ + cycle := p.cycle + dryRun := p.dryRun + sources := make([]JobSource, len(p.sources)) + copy(sources, p.sources) + handlers := make([]JobHandler, len(p.handlers)) + copy(handlers, p.handlers) + p.mu.Unlock() + + log.Info("poller cycle starting", "cycle", cycle, "sources", len(sources), "handlers", len(handlers)) + + for _, src := range sources { + signals, err := src.Poll(ctx) + if err != nil { + log.Error("poll failed", "source", src.Name(), "err", err) + continue + } + + log.Info("polled source", "source", src.Name(), "signals", len(signals)) + + for _, sig := range signals { + handler := p.findHandler(handlers, sig) + if handler == nil { + log.Debug("no matching handler", "epic", sig.EpicNumber, "child", sig.ChildNumber) + continue + } + + if dryRun { + log.Info("dry-run: would execute", + "handler", handler.Name(), + "epic", sig.EpicNumber, + "child", sig.ChildNumber, + "pr", sig.PRNumber, + ) + continue + } + + start := time.Now() + result, err := handler.Execute(ctx, sig) + elapsed := time.Since(start) + + if err != nil { + log.Error("handler execution failed", + "handler", handler.Name(), + "epic", sig.EpicNumber, + "child", sig.ChildNumber, + "err", err, + ) + continue + } + + result.Cycle = cycle + result.EpicNumber = sig.EpicNumber + result.ChildNumber = sig.ChildNumber + result.Duration = elapsed + + if p.journal != nil { + if jErr := p.journal.Append(sig, result); jErr != nil { + log.Error("journal append failed", "err", jErr) + } + } + + if rErr := src.Report(ctx, result); rErr != nil { + log.Error("source report failed", "source", src.Name(), "err", rErr) + } + + log.Info("handler executed", + "handler", handler.Name(), + "action", result.Action, + "success", result.Success, + "duration", elapsed, + ) + } + } + + return nil +} + +// findHandler returns the first handler that matches the signal, or nil. 
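+// Registration order therefore acts as priority: when several handlers match
+// the same signal, the one added first wins, so register specific handlers
+// before catch-all ones.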
+func (p *Poller) findHandler(handlers []JobHandler, sig *PipelineSignal) JobHandler { + for _, h := range handlers { + if h.Match(sig) { + return h + } + } + return nil +} diff --git a/jobrunner/poller_test.go b/jobrunner/poller_test.go new file mode 100644 index 0000000..1d3a908 --- /dev/null +++ b/jobrunner/poller_test.go @@ -0,0 +1,307 @@ +package jobrunner + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --- Mock source --- + +type mockSource struct { + name string + signals []*PipelineSignal + reports []*ActionResult + mu sync.Mutex +} + +func (m *mockSource) Name() string { return m.name } + +func (m *mockSource) Poll(_ context.Context) ([]*PipelineSignal, error) { + m.mu.Lock() + defer m.mu.Unlock() + return m.signals, nil +} + +func (m *mockSource) Report(_ context.Context, result *ActionResult) error { + m.mu.Lock() + defer m.mu.Unlock() + m.reports = append(m.reports, result) + return nil +} + +// --- Mock handler --- + +type mockHandler struct { + name string + matchFn func(*PipelineSignal) bool + executed []*PipelineSignal + mu sync.Mutex +} + +func (m *mockHandler) Name() string { return m.name } + +func (m *mockHandler) Match(sig *PipelineSignal) bool { + if m.matchFn != nil { + return m.matchFn(sig) + } + return true +} + +func (m *mockHandler) Execute(_ context.Context, sig *PipelineSignal) (*ActionResult, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.executed = append(m.executed, sig) + return &ActionResult{ + Action: m.name, + RepoOwner: sig.RepoOwner, + RepoName: sig.RepoName, + PRNumber: sig.PRNumber, + Success: true, + Timestamp: time.Now(), + }, nil +} + +func TestPoller_RunOnce_Good(t *testing.T) { + sig := &PipelineSignal{ + EpicNumber: 1, + ChildNumber: 2, + PRNumber: 10, + RepoOwner: "host-uk", + RepoName: "core-php", + PRState: "OPEN", + CheckStatus: "SUCCESS", + Mergeable: "MERGEABLE", + } + + src := &mockSource{ + name: "test-source", + signals: []*PipelineSignal{sig}, + } + + handler := &mockHandler{ + name: "test-handler", + matchFn: func(s *PipelineSignal) bool { + return s.PRNumber == 10 + }, + } + + p := NewPoller(PollerConfig{ + Sources: []JobSource{src}, + Handlers: []JobHandler{handler}, + }) + + err := p.RunOnce(context.Background()) + require.NoError(t, err) + + // Handler should have been called with our signal. + handler.mu.Lock() + defer handler.mu.Unlock() + require.Len(t, handler.executed, 1) + assert.Equal(t, 10, handler.executed[0].PRNumber) + + // Source should have received a report. + src.mu.Lock() + defer src.mu.Unlock() + require.Len(t, src.reports, 1) + assert.Equal(t, "test-handler", src.reports[0].Action) + assert.True(t, src.reports[0].Success) + assert.Equal(t, 1, src.reports[0].Cycle) + assert.Equal(t, 1, src.reports[0].EpicNumber) + assert.Equal(t, 2, src.reports[0].ChildNumber) + + // Cycle counter should have incremented. + assert.Equal(t, 1, p.Cycle()) +} + +func TestPoller_RunOnce_Good_NoSignals(t *testing.T) { + src := &mockSource{ + name: "empty-source", + signals: nil, + } + + handler := &mockHandler{ + name: "unused-handler", + } + + p := NewPoller(PollerConfig{ + Sources: []JobSource{src}, + Handlers: []JobHandler{handler}, + }) + + err := p.RunOnce(context.Background()) + require.NoError(t, err) + + // Handler should not have been called. + handler.mu.Lock() + defer handler.mu.Unlock() + assert.Empty(t, handler.executed) + + // Source should not have received reports. 
+ src.mu.Lock() + defer src.mu.Unlock() + assert.Empty(t, src.reports) + + assert.Equal(t, 1, p.Cycle()) +} + +func TestPoller_RunOnce_Good_NoMatchingHandler(t *testing.T) { + sig := &PipelineSignal{ + EpicNumber: 5, + ChildNumber: 8, + PRNumber: 42, + RepoOwner: "host-uk", + RepoName: "core-tenant", + PRState: "OPEN", + } + + src := &mockSource{ + name: "test-source", + signals: []*PipelineSignal{sig}, + } + + handler := &mockHandler{ + name: "picky-handler", + matchFn: func(s *PipelineSignal) bool { + return false // never matches + }, + } + + p := NewPoller(PollerConfig{ + Sources: []JobSource{src}, + Handlers: []JobHandler{handler}, + }) + + err := p.RunOnce(context.Background()) + require.NoError(t, err) + + // Handler should not have been called. + handler.mu.Lock() + defer handler.mu.Unlock() + assert.Empty(t, handler.executed) + + // Source should not have received reports (no action taken). + src.mu.Lock() + defer src.mu.Unlock() + assert.Empty(t, src.reports) +} + +func TestPoller_RunOnce_Good_DryRun(t *testing.T) { + sig := &PipelineSignal{ + EpicNumber: 1, + ChildNumber: 3, + PRNumber: 20, + RepoOwner: "host-uk", + RepoName: "core-admin", + PRState: "OPEN", + CheckStatus: "SUCCESS", + Mergeable: "MERGEABLE", + } + + src := &mockSource{ + name: "test-source", + signals: []*PipelineSignal{sig}, + } + + handler := &mockHandler{ + name: "merge-handler", + matchFn: func(s *PipelineSignal) bool { + return true + }, + } + + p := NewPoller(PollerConfig{ + Sources: []JobSource{src}, + Handlers: []JobHandler{handler}, + DryRun: true, + }) + + assert.True(t, p.DryRun()) + + err := p.RunOnce(context.Background()) + require.NoError(t, err) + + // Handler should NOT have been called in dry-run mode. + handler.mu.Lock() + defer handler.mu.Unlock() + assert.Empty(t, handler.executed) + + // Source should not have received reports. + src.mu.Lock() + defer src.mu.Unlock() + assert.Empty(t, src.reports) +} + +func TestPoller_SetDryRun_Good(t *testing.T) { + p := NewPoller(PollerConfig{}) + + assert.False(t, p.DryRun()) + p.SetDryRun(true) + assert.True(t, p.DryRun()) + p.SetDryRun(false) + assert.False(t, p.DryRun()) +} + +func TestPoller_AddSourceAndHandler_Good(t *testing.T) { + p := NewPoller(PollerConfig{}) + + sig := &PipelineSignal{ + EpicNumber: 1, + ChildNumber: 1, + PRNumber: 5, + RepoOwner: "host-uk", + RepoName: "core-php", + PRState: "OPEN", + } + + src := &mockSource{ + name: "added-source", + signals: []*PipelineSignal{sig}, + } + + handler := &mockHandler{ + name: "added-handler", + matchFn: func(s *PipelineSignal) bool { return true }, + } + + p.AddSource(src) + p.AddHandler(handler) + + err := p.RunOnce(context.Background()) + require.NoError(t, err) + + handler.mu.Lock() + defer handler.mu.Unlock() + require.Len(t, handler.executed, 1) + assert.Equal(t, 5, handler.executed[0].PRNumber) +} + +func TestPoller_Run_Good(t *testing.T) { + src := &mockSource{ + name: "tick-source", + signals: nil, + } + + p := NewPoller(PollerConfig{ + Sources: []JobSource{src}, + PollInterval: 50 * time.Millisecond, + }) + + ctx, cancel := context.WithTimeout(context.Background(), 180*time.Millisecond) + defer cancel() + + err := p.Run(ctx) + assert.ErrorIs(t, err, context.DeadlineExceeded) + + // Should have completed at least 2 cycles (one immediate + at least one tick). 
+ assert.GreaterOrEqual(t, p.Cycle(), 2) +} + +func TestPoller_DefaultInterval_Good(t *testing.T) { + p := NewPoller(PollerConfig{}) + assert.Equal(t, 60*time.Second, p.interval) +} diff --git a/jobrunner/types.go b/jobrunner/types.go new file mode 100644 index 0000000..ce51caf --- /dev/null +++ b/jobrunner/types.go @@ -0,0 +1,72 @@ +package jobrunner + +import ( + "context" + "time" +) + +// PipelineSignal is the structural snapshot of a child issue/PR. +// Carries structural state plus issue title/body for dispatch prompts. +type PipelineSignal struct { + EpicNumber int + ChildNumber int + PRNumber int + RepoOwner string + RepoName string + PRState string // OPEN, MERGED, CLOSED + IsDraft bool + Mergeable string // MERGEABLE, CONFLICTING, UNKNOWN + CheckStatus string // SUCCESS, FAILURE, PENDING + ThreadsTotal int + ThreadsResolved int + LastCommitSHA string + LastCommitAt time.Time + LastReviewAt time.Time + NeedsCoding bool // true if child has no PR (work not started) + Assignee string // issue assignee username (for dispatch) + IssueTitle string // child issue title (for dispatch prompt) + IssueBody string // child issue body (for dispatch prompt) + Type string // signal type (e.g., "agent_completion") + Success bool // agent completion success flag + Error string // agent error message + Message string // agent completion message +} + +// RepoFullName returns "owner/repo". +func (s *PipelineSignal) RepoFullName() string { + return s.RepoOwner + "/" + s.RepoName +} + +// HasUnresolvedThreads returns true if there are unresolved review threads. +func (s *PipelineSignal) HasUnresolvedThreads() bool { + return s.ThreadsTotal > s.ThreadsResolved +} + +// ActionResult carries the outcome of a handler execution. +type ActionResult struct { + Action string `json:"action"` + RepoOwner string `json:"repo_owner"` + RepoName string `json:"repo_name"` + EpicNumber int `json:"epic"` + ChildNumber int `json:"child"` + PRNumber int `json:"pr"` + Success bool `json:"success"` + Error string `json:"error,omitempty"` + Timestamp time.Time `json:"ts"` + Duration time.Duration `json:"duration_ms"` + Cycle int `json:"cycle"` +} + +// JobSource discovers actionable work from an external system. +type JobSource interface { + Name() string + Poll(ctx context.Context) ([]*PipelineSignal, error) + Report(ctx context.Context, result *ActionResult) error +} + +// JobHandler processes a single pipeline signal. +type JobHandler interface { + Name() string + Match(signal *PipelineSignal) bool + Execute(ctx context.Context, signal *PipelineSignal) (*ActionResult, error) +} diff --git a/jobrunner/types_test.go b/jobrunner/types_test.go new file mode 100644 index 0000000..c81a840 --- /dev/null +++ b/jobrunner/types_test.go @@ -0,0 +1,98 @@ +package jobrunner + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPipelineSignal_RepoFullName_Good(t *testing.T) { + sig := &PipelineSignal{ + RepoOwner: "host-uk", + RepoName: "core-php", + } + assert.Equal(t, "host-uk/core-php", sig.RepoFullName()) +} + +func TestPipelineSignal_HasUnresolvedThreads_Good(t *testing.T) { + sig := &PipelineSignal{ + ThreadsTotal: 5, + ThreadsResolved: 3, + } + assert.True(t, sig.HasUnresolvedThreads()) +} + +func TestPipelineSignal_HasUnresolvedThreads_Bad_AllResolved(t *testing.T) { + sig := &PipelineSignal{ + ThreadsTotal: 4, + ThreadsResolved: 4, + } + assert.False(t, sig.HasUnresolvedThreads()) + + // Also verify zero threads is not unresolved. 
+ sigZero := &PipelineSignal{ + ThreadsTotal: 0, + ThreadsResolved: 0, + } + assert.False(t, sigZero.HasUnresolvedThreads()) +} + +func TestActionResult_JSON_Good(t *testing.T) { + ts := time.Date(2026, 2, 5, 12, 0, 0, 0, time.UTC) + result := &ActionResult{ + Action: "merge", + RepoOwner: "host-uk", + RepoName: "core-tenant", + EpicNumber: 42, + ChildNumber: 7, + PRNumber: 99, + Success: true, + Timestamp: ts, + Duration: 1500 * time.Millisecond, + Cycle: 3, + } + + data, err := json.Marshal(result) + require.NoError(t, err) + + var decoded map[string]any + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + + assert.Equal(t, "merge", decoded["action"]) + assert.Equal(t, "host-uk", decoded["repo_owner"]) + assert.Equal(t, "core-tenant", decoded["repo_name"]) + assert.Equal(t, float64(42), decoded["epic"]) + assert.Equal(t, float64(7), decoded["child"]) + assert.Equal(t, float64(99), decoded["pr"]) + assert.Equal(t, true, decoded["success"]) + assert.Equal(t, float64(3), decoded["cycle"]) + + // Error field should be omitted when empty. + _, hasError := decoded["error"] + assert.False(t, hasError, "error field should be omitted when empty") + + // Verify round-trip with error field present. + resultWithErr := &ActionResult{ + Action: "merge", + RepoOwner: "host-uk", + RepoName: "core-tenant", + Success: false, + Error: "checks failing", + Timestamp: ts, + Duration: 200 * time.Millisecond, + Cycle: 1, + } + data2, err := json.Marshal(resultWithErr) + require.NoError(t, err) + + var decoded2 map[string]any + err = json.Unmarshal(data2, &decoded2) + require.NoError(t, err) + + assert.Equal(t, "checks failing", decoded2["error"]) + assert.Equal(t, false, decoded2["success"]) +}
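The JobSource and JobHandler interfaces above are the package's extension points. As a minimal sketch of how a consumer plugs into them, assuming only the types defined in this diff (the nudgeHandler type, its match rule, the journal path, and the main function are hypothetical, not part of this change):

package main // hypothetical consumer of the jobrunner package

import (
	"context"
	"time"

	"forge.lthn.ai/core/go-scm/jobrunner"
)

// nudgeHandler matches open PRs that still have unresolved review threads.
type nudgeHandler struct{}

func (h *nudgeHandler) Name() string { return "nudge" }

func (h *nudgeHandler) Match(sig *jobrunner.PipelineSignal) bool {
	return sig.PRState == "OPEN" && sig.HasUnresolvedThreads()
}

func (h *nudgeHandler) Execute(ctx context.Context, sig *jobrunner.PipelineSignal) (*jobrunner.ActionResult, error) {
	start := time.Now()
	// A real handler would call the forge API here.
	return &jobrunner.ActionResult{
		Action:    "nudge",
		RepoOwner: sig.RepoOwner,
		RepoName:  sig.RepoName,
		PRNumber:  sig.PRNumber,
		Success:   true,
		Timestamp: start,
		Duration:  time.Since(start),
	}, nil
}

func main() {
	journal, _ := jobrunner.NewJournal("/var/log/jobrunner") // path is illustrative
	p := jobrunner.NewPoller(jobrunner.PollerConfig{
		Handlers:     []jobrunner.JobHandler{&nudgeHandler{}},
		Journal:      journal,
		PollInterval: time.Minute,
		DryRun:       true, // observe decisions without acting
	})
	_ = p.Run(context.Background()) // add a JobSource via p.AddSource before running
}

One caveat when consuming ActionResult JSON directly: Duration is a time.Duration, which encoding/json encodes as integer nanoseconds even though the tag reads duration_ms. The journal's ResultSnapshot converts explicitly via Duration.Milliseconds(), so the JSONL files themselves do carry milliseconds.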