diff --git a/.github/workflows/auto-merge.yml b/.github/workflows/auto-merge.yml new file mode 100644 index 00000000..ec3cf86b --- /dev/null +++ b/.github/workflows/auto-merge.yml @@ -0,0 +1,40 @@ +name: Auto Merge + +on: + pull_request: + types: [opened, reopened, ready_for_review] + +permissions: + contents: write + pull-requests: write + +jobs: + auto-merge: + if: "!github.event.pull_request.draft" + runs-on: ubuntu-latest + steps: + - name: Check org membership and enable auto-merge + uses: actions/github-script@v7 + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + PR_NUMBER: ${{ github.event.pull_request.number }} + with: + script: | + const { owner, repo } = context.repo; + const author = context.payload.pull_request.user.login; + + try { + await github.rest.orgs.checkMembershipForUser({ + org: owner, + username: author, + }); + } catch { + core.info(`${author} is not an org member — skipping auto-merge`); + return; + } + + await exec.exec('gh', [ + 'pr', 'merge', process.env.PR_NUMBER, + '--auto', '--squash', + ]); + core.info(`Auto-merge enabled for #${process.env.PR_NUMBER}`); diff --git a/.github/workflows/pr-gate.yml b/.github/workflows/pr-gate.yml new file mode 100644 index 00000000..299f186b --- /dev/null +++ b/.github/workflows/pr-gate.yml @@ -0,0 +1,42 @@ +name: PR Gate + +on: + pull_request_target: + types: [opened, synchronize, reopened, labeled] + +permissions: + contents: read + +jobs: + org-gate: + runs-on: ubuntu-latest + steps: + - name: Check org membership or approval label + uses: actions/github-script@v7 + with: + script: | + const { owner, repo } = context.repo; + const author = context.payload.pull_request.user.login; + + // Check if author is an org member + try { + await github.rest.orgs.checkMembershipForUser({ + org: owner, + username: author, + }); + core.info(`${author} is an org member — gate passed`); + return; + } catch { + core.info(`${author} is not an org member — checking for label`); + } + + // Check for external-approved label + const labels = context.payload.pull_request.labels.map(l => l.name); + if (labels.includes('external-approved')) { + core.info('external-approved label present — gate passed'); + return; + } + + core.setFailed( + `External PR from ${author} requires an org member to add the "external-approved" label before merge.` + ); diff --git a/core-test b/core-test deleted file mode 100755 index 65048b84..00000000 Binary files a/core-test and /dev/null differ diff --git a/internal/cmd/ci/cmd_init.go b/internal/cmd/ci/cmd_init.go index 59e4958c..aa7d022c 100644 --- a/internal/cmd/ci/cmd_init.go +++ b/internal/cmd/ci/cmd_init.go @@ -5,6 +5,7 @@ import ( "github.com/host-uk/core/pkg/cli" "github.com/host-uk/core/pkg/i18n" + "github.com/host-uk/core/pkg/io" "github.com/host-uk/core/pkg/release" ) @@ -17,14 +18,14 @@ func runCIReleaseInit() error { cli.Print("%s %s\n\n", releaseDimStyle.Render(i18n.Label("init")), i18n.T("cmd.ci.init.initializing")) // Check if already initialized - if release.ConfigExists(cwd) { + if release.ConfigExists(io.Local, cwd) { cli.Text(i18n.T("cmd.ci.init.already_initialized")) return nil } // Create release config cfg := release.DefaultConfig() - if err := release.WriteConfig(cfg, cwd); err != nil { + if err := release.WriteConfig(io.Local, cfg, cwd); err != nil { return cli.Err("%s: %w", i18n.T("i18n.fail.create", "config"), err) } diff --git a/internal/cmd/ci/cmd_publish.go b/internal/cmd/ci/cmd_publish.go index 23b0c4ef..4dc73c2e 100644 --- a/internal/cmd/ci/cmd_publish.go +++ b/internal/cmd/ci/cmd_publish.go @@ -7,6 +7,7 @@ import ( "github.com/host-uk/core/pkg/cli" "github.com/host-uk/core/pkg/i18n" + "github.com/host-uk/core/pkg/io" "github.com/host-uk/core/pkg/release" ) @@ -22,7 +23,7 @@ func runCIPublish(dryRun bool, version string, draft, prerelease bool) error { } // Load configuration - cfg, err := release.LoadConfig(projectDir) + cfg, err := release.LoadConfig(io.Local, projectDir) if err != nil { return cli.WrapVerb(err, "load", "config") } diff --git a/internal/cmd/dev/cmd_vm.go b/internal/cmd/dev/cmd_vm.go index 71a4ac23..52ef2104 100644 --- a/internal/cmd/dev/cmd_vm.go +++ b/internal/cmd/dev/cmd_vm.go @@ -9,6 +9,7 @@ import ( "github.com/host-uk/core/pkg/cli" "github.com/host-uk/core/pkg/devops" "github.com/host-uk/core/pkg/i18n" + "github.com/host-uk/core/pkg/io" ) // addVMCommands adds the dev environment VM commands to the dev parent command. @@ -40,7 +41,7 @@ func addVMInstallCommand(parent *cli.Command) { } func runVMInstall() error { - d, err := devops.New() + d, err := devops.New(io.Local) if err != nil { return err } @@ -112,7 +113,7 @@ func addVMBootCommand(parent *cli.Command) { } func runVMBoot(memory, cpus int, fresh bool) error { - d, err := devops.New() + d, err := devops.New(io.Local) if err != nil { return err } @@ -163,7 +164,7 @@ func addVMStopCommand(parent *cli.Command) { } func runVMStop() error { - d, err := devops.New() + d, err := devops.New(io.Local) if err != nil { return err } @@ -204,7 +205,7 @@ func addVMStatusCommand(parent *cli.Command) { } func runVMStatus() error { - d, err := devops.New() + d, err := devops.New(io.Local) if err != nil { return err } @@ -283,7 +284,7 @@ func addVMShellCommand(parent *cli.Command) { } func runVMShell(console bool, command []string) error { - d, err := devops.New() + d, err := devops.New(io.Local) if err != nil { return err } @@ -321,7 +322,7 @@ func addVMServeCommand(parent *cli.Command) { } func runVMServe(port int, path string) error { - d, err := devops.New() + d, err := devops.New(io.Local) if err != nil { return err } @@ -360,7 +361,7 @@ func addVMTestCommand(parent *cli.Command) { } func runVMTest(name string, command []string) error { - d, err := devops.New() + d, err := devops.New(io.Local) if err != nil { return err } @@ -405,7 +406,7 @@ func addVMClaudeCommand(parent *cli.Command) { } func runVMClaude(noAuth bool, model string, authFlags []string) error { - d, err := devops.New() + d, err := devops.New(io.Local) if err != nil { return err } @@ -445,7 +446,7 @@ func addVMUpdateCommand(parent *cli.Command) { } func runVMUpdate(apply bool) error { - d, err := devops.New() + d, err := devops.New(io.Local) if err != nil { return err } diff --git a/internal/cmd/go/cmd_qa.go b/internal/cmd/go/cmd_qa.go index 2ac1dfc5..ba086ee4 100644 --- a/internal/cmd/go/cmd_qa.go +++ b/internal/cmd/go/cmd_qa.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "os/exec" + "regexp" "strings" "time" @@ -147,6 +148,7 @@ type CheckResult struct { Duration string `json:"duration"` Error string `json:"error,omitempty"` Output string `json:"output,omitempty"` + FixHint string `json:"fix_hint,omitempty"` } func runGoQA(cmd *cli.Command, args []string) error { @@ -218,6 +220,7 @@ func runGoQA(cmd *cli.Command, args []string) error { if qaVerbose { result.Output = output } + result.FixHint = fixHintFor(check.Name, output) failed++ if !qaJSON && !qaQuiet { @@ -225,6 +228,9 @@ func runGoQA(cmd *cli.Command, args []string) error { if qaVerbose && output != "" { cli.Text(output) } + if result.FixHint != "" { + cli.Hint("fix", result.FixHint) + } } if qaFailFast { @@ -260,6 +266,7 @@ func runGoQA(cmd *cli.Command, args []string) error { if !qaJSON && !qaQuiet { cli.Print(" %s Coverage %.1f%% below threshold %.1f%%\n", cli.ErrorStyle.Render(cli.Glyph(":cross:")), cov, qaThreshold) + cli.Hint("fix", "Run 'core go cov --open' to see uncovered lines, then add tests.") } } } @@ -436,6 +443,47 @@ func buildCheck(name string) QACheck { } } +// fixHintFor returns an actionable fix instruction for a given check failure. +func fixHintFor(checkName, output string) string { + switch checkName { + case "format", "fmt": + return "Run 'core go qa fmt --fix' to auto-format." + case "vet": + return "Fix the issues reported by go vet — typically genuine bugs." + case "lint": + return "Run 'core go qa lint --fix' for auto-fixable issues." + case "test": + if name := extractFailingTest(output); name != "" { + return fmt.Sprintf("Run 'go test -run %s -v ./...' to debug.", name) + } + return "Run 'go test -run -v ./path/' to debug." + case "race": + return "Data race detected. Add mutex, channel, or atomic to synchronise shared state." + case "bench": + return "Benchmark regression. Run 'go test -bench=. -benchmem' to reproduce." + case "vuln": + return "Run 'govulncheck ./...' for details. Update affected deps with 'go get -u'." + case "sec": + return "Review gosec findings. Common fixes: validate inputs, parameterised queries." + case "fuzz": + return "Add a regression test for the crashing input in testdata/fuzz//." + case "docblock": + return "Add doc comments to exported symbols: '// Name does X.' before each declaration." + default: + return "" + } +} + +var failTestRe = regexp.MustCompile(`--- FAIL: (\w+)`) + +// extractFailingTest parses the first failing test name from go test output. +func extractFailingTest(output string) string { + if m := failTestRe.FindStringSubmatch(output); len(m) > 1 { + return m[1] + } + return "" +} + func runCheckCapture(ctx context.Context, dir string, check QACheck) (string, error) { // Handle internal checks if check.Command == "_internal_" { @@ -528,8 +576,8 @@ func runCoverage(ctx context.Context, dir string) (float64, error) { func runInternalCheck(check QACheck) (string, error) { switch check.Name { case "fuzz": - // Short burst fuzz in QA (5s per target) - duration := 5 * time.Second + // Short burst fuzz in QA (3s per target) + duration := 3 * time.Second if qaTimeout > 0 && qaTimeout < 30*time.Second { duration = 2 * time.Second } diff --git a/internal/cmd/pkgcmd/cmd_search.go b/internal/cmd/pkgcmd/cmd_search.go index c672ca72..5b34cbc1 100644 --- a/internal/cmd/pkgcmd/cmd_search.go +++ b/internal/cmd/pkgcmd/cmd_search.go @@ -73,7 +73,7 @@ func runPkgSearch(org, pattern, repoType string, limit int, refresh bool) error cacheDir = filepath.Join(filepath.Dir(regPath), ".core", "cache") } - c, err := cache.New(cacheDir, 0) + c, err := cache.New(nil, cacheDir, 0) if err != nil { c = nil } diff --git a/internal/cmd/vm/cmd_container.go b/internal/cmd/vm/cmd_container.go index 38622a54..fa9246fe 100644 --- a/internal/cmd/vm/cmd_container.go +++ b/internal/cmd/vm/cmd_container.go @@ -4,7 +4,7 @@ import ( "context" "errors" "fmt" - "io" + goio "io" "os" "strings" "text/tabwriter" @@ -12,6 +12,7 @@ import ( "github.com/host-uk/core/pkg/container" "github.com/host-uk/core/pkg/i18n" + "github.com/host-uk/core/pkg/io" "github.com/spf13/cobra" ) @@ -68,7 +69,7 @@ func addVMRunCommand(parent *cobra.Command) { } func runContainer(image, name string, detach bool, memory, cpus, sshPort int) error { - manager, err := container.NewLinuxKitManager() + manager, err := container.NewLinuxKitManager(io.Local) if err != nil { return fmt.Errorf(i18n.T("i18n.fail.init", "container manager")+": %w", err) } @@ -126,7 +127,7 @@ func addVMPsCommand(parent *cobra.Command) { } func listContainers(all bool) error { - manager, err := container.NewLinuxKitManager() + manager, err := container.NewLinuxKitManager(io.Local) if err != nil { return fmt.Errorf(i18n.T("i18n.fail.init", "container manager")+": %w", err) } @@ -221,7 +222,7 @@ func addVMStopCommand(parent *cobra.Command) { } func stopContainer(id string) error { - manager, err := container.NewLinuxKitManager() + manager, err := container.NewLinuxKitManager(io.Local) if err != nil { return fmt.Errorf(i18n.T("i18n.fail.init", "container manager")+": %w", err) } @@ -290,7 +291,7 @@ func addVMLogsCommand(parent *cobra.Command) { } func viewLogs(id string, follow bool) error { - manager, err := container.NewLinuxKitManager() + manager, err := container.NewLinuxKitManager(io.Local) if err != nil { return fmt.Errorf(i18n.T("i18n.fail.init", "container manager")+": %w", err) } @@ -307,7 +308,7 @@ func viewLogs(id string, follow bool) error { } defer func() { _ = reader.Close() }() - _, err = io.Copy(os.Stdout, reader) + _, err = goio.Copy(os.Stdout, reader) return err } @@ -329,7 +330,7 @@ func addVMExecCommand(parent *cobra.Command) { } func execInContainer(id string, cmd []string) error { - manager, err := container.NewLinuxKitManager() + manager, err := container.NewLinuxKitManager(io.Local) if err != nil { return fmt.Errorf(i18n.T("i18n.fail.init", "container manager")+": %w", err) } diff --git a/internal/cmd/vm/cmd_templates.go b/internal/cmd/vm/cmd_templates.go index 31989df1..c03253e5 100644 --- a/internal/cmd/vm/cmd_templates.go +++ b/internal/cmd/vm/cmd_templates.go @@ -12,9 +12,12 @@ import ( "github.com/host-uk/core/pkg/container" "github.com/host-uk/core/pkg/i18n" + "github.com/host-uk/core/pkg/io" "github.com/spf13/cobra" ) +var templateManager = container.NewTemplateManager(io.Local) + // addVMTemplatesCommand adds the 'templates' command under vm. func addVMTemplatesCommand(parent *cobra.Command) { templatesCmd := &cobra.Command{ @@ -68,7 +71,7 @@ func addTemplatesVarsCommand(parent *cobra.Command) { } func listTemplates() error { - templates := container.ListTemplates() + templates := templateManager.ListTemplates() if len(templates) == 0 { fmt.Println(i18n.T("cmd.vm.templates.no_templates")) @@ -99,7 +102,7 @@ func listTemplates() error { } func showTemplate(name string) error { - content, err := container.GetTemplate(name) + content, err := templateManager.GetTemplate(name) if err != nil { return err } @@ -111,7 +114,7 @@ func showTemplate(name string) error { } func showTemplateVars(name string) error { - content, err := container.GetTemplate(name) + content, err := templateManager.GetTemplate(name) if err != nil { return err } @@ -148,7 +151,7 @@ func showTemplateVars(name string) error { // RunFromTemplate builds and runs a LinuxKit image from a template. func RunFromTemplate(templateName string, vars map[string]string, runOpts container.RunOptions) error { // Apply template with variables - content, err := container.ApplyTemplate(templateName, vars) + content, err := templateManager.ApplyTemplate(templateName, vars) if err != nil { return fmt.Errorf(i18n.T("common.error.failed", map[string]any{"Action": "apply template"})+": %w", err) } @@ -185,7 +188,7 @@ func RunFromTemplate(templateName string, vars map[string]string, runOpts contai fmt.Println() // Run the image - manager, err := container.NewLinuxKitManager() + manager, err := container.NewLinuxKitManager(io.Local) if err != nil { return fmt.Errorf(i18n.T("common.error.failed", map[string]any{"Action": "initialize container manager"})+": %w", err) } @@ -196,7 +199,7 @@ func RunFromTemplate(templateName string, vars map[string]string, runOpts contai ctx := context.Background() c, err := manager.Run(ctx, imagePath, runOpts) if err != nil { - return fmt.Errorf(i18n.T("common.error.failed", map[string]any{"Action": "run container"})+": %w", err) + return fmt.Errorf(i18n.T("i18n.fail.run", "container")+": %w", err) } if runOpts.Detach { diff --git a/pkg/build/buildcmd/cmd_release.go b/pkg/build/buildcmd/cmd_release.go index 330c96b3..e08be39b 100644 --- a/pkg/build/buildcmd/cmd_release.go +++ b/pkg/build/buildcmd/cmd_release.go @@ -9,6 +9,7 @@ import ( "github.com/host-uk/core/pkg/cli" "github.com/host-uk/core/pkg/framework/core" "github.com/host-uk/core/pkg/i18n" + "github.com/host-uk/core/pkg/io" "github.com/host-uk/core/pkg/release" ) @@ -50,7 +51,7 @@ func runRelease(ctx context.Context, dryRun bool, version string, draft, prerele } // Check for release config - if !release.ConfigExists(projectDir) { + if !release.ConfigExists(io.Local, projectDir) { cli.Print("%s %s\n", buildErrorStyle.Render(i18n.Label("error")), i18n.T("cmd.build.release.error.no_config"), @@ -60,7 +61,7 @@ func runRelease(ctx context.Context, dryRun bool, version string, draft, prerele } // Load configuration - cfg, err := release.LoadConfig(projectDir) + cfg, err := release.LoadConfig(io.Local, projectDir) if err != nil { return core.E("release", "load config", err) } diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go index f660e421..ca60a63c 100644 --- a/pkg/cache/cache.go +++ b/pkg/cache/cache.go @@ -3,6 +3,8 @@ package cache import ( "encoding/json" + "errors" + "io/fs" "os" "path/filepath" "time" @@ -15,6 +17,7 @@ const DefaultTTL = 1 * time.Hour // Cache represents a file-based cache. type Cache struct { + medium io.Medium baseDir string ttl time.Duration } @@ -27,8 +30,13 @@ type Entry struct { } // New creates a new cache instance. -// If baseDir is empty, uses .core/cache in current directory -func New(baseDir string, ttl time.Duration) (*Cache, error) { +// If baseDir is empty, uses .core/cache in current directory. +// If m is nil, uses io.Local. +func New(m io.Medium, baseDir string, ttl time.Duration) (*Cache, error) { + if m == nil { + m = io.Local + } + if baseDir == "" { // Use .core/cache in current working directory cwd, err := os.Getwd() @@ -42,20 +50,21 @@ func New(baseDir string, ttl time.Duration) (*Cache, error) { ttl = DefaultTTL } - // Convert to absolute path for io.Local + // Convert to absolute path for consistency absBaseDir, err := filepath.Abs(baseDir) if err != nil { return nil, err } // Ensure cache directory exists - if err := io.Local.EnsureDir(absBaseDir); err != nil { + if err := m.EnsureDir(absBaseDir); err != nil { return nil, err } baseDir = absBaseDir return &Cache{ + medium: m, baseDir: baseDir, ttl: ttl, }, nil @@ -70,9 +79,9 @@ func (c *Cache) Path(key string) string { func (c *Cache) Get(key string, dest interface{}) (bool, error) { path := c.Path(key) - content, err := io.Local.Read(path) + content, err := c.medium.Read(path) if err != nil { - if os.IsNotExist(err) { + if errors.Is(err, fs.ErrNotExist) || os.IsNotExist(err) { return false, nil } return false, err @@ -119,15 +128,15 @@ func (c *Cache) Set(key string, data interface{}) error { return err } - // io.Local.Write creates parent directories automatically - return io.Local.Write(path, string(entryBytes)) + // medium.Write creates parent directories automatically + return c.medium.Write(path, string(entryBytes)) } // Delete removes an item from the cache. func (c *Cache) Delete(key string) error { path := c.Path(key) - err := io.Local.Delete(path) - if os.IsNotExist(err) { + err := c.medium.Delete(path) + if errors.Is(err, fs.ErrNotExist) || os.IsNotExist(err) { return nil } return err @@ -135,14 +144,14 @@ func (c *Cache) Delete(key string) error { // Clear removes all cached items. func (c *Cache) Clear() error { - return io.Local.DeleteAll(c.baseDir) + return c.medium.DeleteAll(c.baseDir) } // Age returns how old a cached item is, or -1 if not cached. func (c *Cache) Age(key string) time.Duration { path := c.Path(key) - content, err := io.Local.Read(path) + content, err := c.medium.Read(path) if err != nil { return -1 } diff --git a/pkg/cache/cache_test.go b/pkg/cache/cache_test.go new file mode 100644 index 00000000..87d52586 --- /dev/null +++ b/pkg/cache/cache_test.go @@ -0,0 +1,104 @@ +package cache_test + +import ( + "testing" + "time" + + "github.com/host-uk/core/pkg/cache" + "github.com/host-uk/core/pkg/io" +) + +func TestCache(t *testing.T) { + m := io.NewMockMedium() + // Use a path that MockMedium will understand + baseDir := "/tmp/cache" + c, err := cache.New(m, baseDir, 1*time.Minute) + if err != nil { + t.Fatalf("failed to create cache: %v", err) + } + + key := "test-key" + data := map[string]string{"foo": "bar"} + + // Test Set + if err := c.Set(key, data); err != nil { + t.Errorf("Set failed: %v", err) + } + + // Test Get + var retrieved map[string]string + found, err := c.Get(key, &retrieved) + if err != nil { + t.Errorf("Get failed: %v", err) + } + if !found { + t.Error("expected to find cached item") + } + if retrieved["foo"] != "bar" { + t.Errorf("expected foo=bar, got %v", retrieved["foo"]) + } + + // Test Age + age := c.Age(key) + if age < 0 { + t.Error("expected age >= 0") + } + + // Test Delete + if err := c.Delete(key); err != nil { + t.Errorf("Delete failed: %v", err) + } + found, err = c.Get(key, &retrieved) + if err != nil { + t.Errorf("Get after delete returned an unexpected error: %v", err) + } + if found { + t.Error("expected item to be deleted") + } + + // Test Expiry + cshort, err := cache.New(m, "/tmp/cache-short", 10*time.Millisecond) + if err != nil { + t.Fatalf("failed to create short-lived cache: %v", err) + } + if err := cshort.Set(key, data); err != nil { + t.Fatalf("Set for expiry test failed: %v", err) + } + time.Sleep(50 * time.Millisecond) + found, err = cshort.Get(key, &retrieved) + if err != nil { + t.Errorf("Get for expired item returned an unexpected error: %v", err) + } + if found { + t.Error("expected item to be expired") + } + + // Test Clear + if err := c.Set("key1", data); err != nil { + t.Fatalf("Set for clear test failed for key1: %v", err) + } + if err := c.Set("key2", data); err != nil { + t.Fatalf("Set for clear test failed for key2: %v", err) + } + if err := c.Clear(); err != nil { + t.Errorf("Clear failed: %v", err) + } + found, err = c.Get("key1", &retrieved) + if err != nil { + t.Errorf("Get after clear returned an unexpected error: %v", err) + } + if found { + t.Error("expected key1 to be cleared") + } +} + +func TestCacheDefaults(t *testing.T) { + // Test default Medium (io.Local) and default TTL + c, err := cache.New(nil, "", 0) + if err != nil { + t.Fatalf("failed to create cache with defaults: %v", err) + } + if c == nil { + t.Fatal("expected cache instance") + } +} diff --git a/pkg/container/linuxkit.go b/pkg/container/linuxkit.go index 2f2780af..a5371f73 100644 --- a/pkg/container/linuxkit.go +++ b/pkg/container/linuxkit.go @@ -17,16 +17,17 @@ import ( type LinuxKitManager struct { state *State hypervisor Hypervisor + medium io.Medium } // NewLinuxKitManager creates a new LinuxKit manager with auto-detected hypervisor. -func NewLinuxKitManager() (*LinuxKitManager, error) { +func NewLinuxKitManager(m io.Medium) (*LinuxKitManager, error) { statePath, err := DefaultStatePath() if err != nil { return nil, fmt.Errorf("failed to determine state path: %w", err) } - state, err := LoadState(statePath) + state, err := LoadState(m, statePath) if err != nil { return nil, fmt.Errorf("failed to load state: %w", err) } @@ -39,21 +40,23 @@ func NewLinuxKitManager() (*LinuxKitManager, error) { return &LinuxKitManager{ state: state, hypervisor: hypervisor, + medium: m, }, nil } // NewLinuxKitManagerWithHypervisor creates a manager with a specific hypervisor. -func NewLinuxKitManagerWithHypervisor(state *State, hypervisor Hypervisor) *LinuxKitManager { +func NewLinuxKitManagerWithHypervisor(m io.Medium, state *State, hypervisor Hypervisor) *LinuxKitManager { return &LinuxKitManager{ state: state, hypervisor: hypervisor, + medium: m, } } // Run starts a new LinuxKit VM from the given image. func (m *LinuxKitManager) Run(ctx context.Context, image string, opts RunOptions) (*Container, error) { // Validate image exists - if !io.Local.IsFile(image) { + if !m.medium.IsFile(image) { return nil, fmt.Errorf("image not found: %s", image) } @@ -87,7 +90,7 @@ func (m *LinuxKitManager) Run(ctx context.Context, image string, opts RunOptions } // Ensure logs directory exists - if err := EnsureLogsDir(); err != nil { + if err := EnsureLogsDir(m.medium); err != nil { return nil, fmt.Errorf("failed to create logs directory: %w", err) } @@ -329,35 +332,36 @@ func (m *LinuxKitManager) Logs(ctx context.Context, id string, follow bool) (goi return nil, fmt.Errorf("failed to determine log path: %w", err) } - if !io.Local.IsFile(logPath) { + if !m.medium.IsFile(logPath) { return nil, fmt.Errorf("no logs available for container: %s", id) } if !follow { // Simple case: just open and return the file - return os.Open(logPath) + return m.medium.Open(logPath) } // Follow mode: create a reader that tails the file - return newFollowReader(ctx, logPath) + return newFollowReader(ctx, m.medium, logPath) } // followReader implements goio.ReadCloser for following log files. type followReader struct { - file *os.File + file goio.ReadCloser ctx context.Context cancel context.CancelFunc reader *bufio.Reader + medium io.Medium + path string } -func newFollowReader(ctx context.Context, path string) (*followReader, error) { - file, err := os.Open(path) +func newFollowReader(ctx context.Context, m io.Medium, path string) (*followReader, error) { + file, err := m.Open(path) if err != nil { return nil, err } - // Seek to end - _, _ = file.Seek(0, goio.SeekEnd) + // Note: We don't seek here because Medium.Open doesn't guarantee Seekability. ctx, cancel := context.WithCancel(ctx) @@ -366,6 +370,8 @@ func newFollowReader(ctx context.Context, path string) (*followReader, error) { ctx: ctx, cancel: cancel, reader: bufio.NewReader(file), + medium: m, + path: path, }, nil } diff --git a/pkg/container/linuxkit_test.go b/pkg/container/linuxkit_test.go index 2a03cb07..b943898a 100644 --- a/pkg/container/linuxkit_test.go +++ b/pkg/container/linuxkit_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/host-uk/core/pkg/io" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -63,11 +64,11 @@ func newTestManager(t *testing.T) (*LinuxKitManager, *MockHypervisor, string) { statePath := filepath.Join(tmpDir, "containers.json") - state, err := LoadState(statePath) + state, err := LoadState(io.Local, statePath) require.NoError(t, err) mock := NewMockHypervisor() - manager := NewLinuxKitManagerWithHypervisor(state, mock) + manager := NewLinuxKitManagerWithHypervisor(io.Local, state, mock) return manager, mock, tmpDir } @@ -75,10 +76,10 @@ func newTestManager(t *testing.T) (*LinuxKitManager, *MockHypervisor, string) { func TestNewLinuxKitManagerWithHypervisor_Good(t *testing.T) { tmpDir := t.TempDir() statePath := filepath.Join(tmpDir, "containers.json") - state, _ := LoadState(statePath) + state, _ := LoadState(io.Local, statePath) mock := NewMockHypervisor() - manager := NewLinuxKitManagerWithHypervisor(state, mock) + manager := NewLinuxKitManagerWithHypervisor(io.Local, state, mock) assert.NotNil(t, manager) assert.Equal(t, state, manager.State()) @@ -213,9 +214,9 @@ func TestLinuxKitManager_Stop_Bad_NotFound(t *testing.T) { func TestLinuxKitManager_Stop_Bad_NotRunning(t *testing.T) { _, _, tmpDir := newTestManager(t) statePath := filepath.Join(tmpDir, "containers.json") - state, err := LoadState(statePath) + state, err := LoadState(io.Local, statePath) require.NoError(t, err) - manager := NewLinuxKitManagerWithHypervisor(state, NewMockHypervisor()) + manager := NewLinuxKitManagerWithHypervisor(io.Local, state, NewMockHypervisor()) container := &Container{ ID: "abc12345", @@ -233,9 +234,9 @@ func TestLinuxKitManager_Stop_Bad_NotRunning(t *testing.T) { func TestLinuxKitManager_List_Good(t *testing.T) { _, _, tmpDir := newTestManager(t) statePath := filepath.Join(tmpDir, "containers.json") - state, err := LoadState(statePath) + state, err := LoadState(io.Local, statePath) require.NoError(t, err) - manager := NewLinuxKitManagerWithHypervisor(state, NewMockHypervisor()) + manager := NewLinuxKitManagerWithHypervisor(io.Local, state, NewMockHypervisor()) _ = state.Add(&Container{ID: "aaa11111", Status: StatusStopped}) _ = state.Add(&Container{ID: "bbb22222", Status: StatusStopped}) @@ -250,9 +251,9 @@ func TestLinuxKitManager_List_Good(t *testing.T) { func TestLinuxKitManager_List_Good_VerifiesRunningStatus(t *testing.T) { _, _, tmpDir := newTestManager(t) statePath := filepath.Join(tmpDir, "containers.json") - state, err := LoadState(statePath) + state, err := LoadState(io.Local, statePath) require.NoError(t, err) - manager := NewLinuxKitManagerWithHypervisor(state, NewMockHypervisor()) + manager := NewLinuxKitManagerWithHypervisor(io.Local, state, NewMockHypervisor()) // Add a "running" container with a fake PID that doesn't exist _ = state.Add(&Container{ @@ -475,7 +476,7 @@ func TestFollowReader_Read_Good_WithData(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - reader, err := newFollowReader(ctx, logPath) + reader, err := newFollowReader(ctx, io.Local, logPath) require.NoError(t, err) defer func() { _ = reader.Close() }() @@ -506,7 +507,7 @@ func TestFollowReader_Read_Good_ContextCancel(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - reader, err := newFollowReader(ctx, logPath) + reader, err := newFollowReader(ctx, io.Local, logPath) require.NoError(t, err) // Cancel the context @@ -528,7 +529,7 @@ func TestFollowReader_Close_Good(t *testing.T) { require.NoError(t, err) ctx := context.Background() - reader, err := newFollowReader(ctx, logPath) + reader, err := newFollowReader(ctx, io.Local, logPath) require.NoError(t, err) err = reader.Close() @@ -542,7 +543,7 @@ func TestFollowReader_Close_Good(t *testing.T) { func TestNewFollowReader_Bad_FileNotFound(t *testing.T) { ctx := context.Background() - _, err := newFollowReader(ctx, "/nonexistent/path/to/file.log") + _, err := newFollowReader(ctx, io.Local, "/nonexistent/path/to/file.log") assert.Error(t, err) } @@ -672,7 +673,7 @@ func TestLinuxKitManager_Run_Good_WithPortsAndVolumes(t *testing.T) { time.Sleep(50 * time.Millisecond) } -func TestFollowReader_Read_Good_ReaderError(t *testing.T) { +func TestFollowReader_Read_Bad_ReaderError(t *testing.T) { tmpDir := t.TempDir() logPath := filepath.Join(tmpDir, "test.log") @@ -681,7 +682,7 @@ func TestFollowReader_Read_Good_ReaderError(t *testing.T) { require.NoError(t, err) ctx := context.Background() - reader, err := newFollowReader(ctx, logPath) + reader, err := newFollowReader(ctx, io.Local, logPath) require.NoError(t, err) // Close the underlying file to cause read errors diff --git a/pkg/container/state.go b/pkg/container/state.go index e99bb051..376952c9 100644 --- a/pkg/container/state.go +++ b/pkg/container/state.go @@ -15,6 +15,7 @@ type State struct { Containers map[string]*Container `json:"containers"` mu sync.RWMutex + medium io.Medium filePath string } @@ -46,24 +47,25 @@ func DefaultLogsDir() (string, error) { } // NewState creates a new State instance. -func NewState(filePath string) *State { +func NewState(m io.Medium, filePath string) *State { return &State{ Containers: make(map[string]*Container), + medium: m, filePath: filePath, } } // LoadState loads the state from the given file path. // If the file doesn't exist, returns an empty state. -func LoadState(filePath string) (*State, error) { - state := NewState(filePath) +func LoadState(m io.Medium, filePath string) (*State, error) { + state := NewState(m, filePath) absPath, err := filepath.Abs(filePath) if err != nil { return nil, err } - content, err := io.Local.Read(absPath) + content, err := m.Read(absPath) if err != nil { if os.IsNotExist(err) { return state, nil @@ -93,8 +95,8 @@ func (s *State) SaveState() error { return err } - // io.Local.Write creates parent directories automatically - return io.Local.Write(absPath, string(data)) + // s.medium.Write creates parent directories automatically + return s.medium.Write(absPath, string(data)) } // Add adds a container to the state and persists it. @@ -168,10 +170,10 @@ func LogPath(id string) (string, error) { } // EnsureLogsDir ensures the logs directory exists. -func EnsureLogsDir() error { +func EnsureLogsDir(m io.Medium) error { logsDir, err := DefaultLogsDir() if err != nil { return err } - return io.Local.EnsureDir(logsDir) + return m.EnsureDir(logsDir) } diff --git a/pkg/container/state_test.go b/pkg/container/state_test.go index 68e6a023..a7c28003 100644 --- a/pkg/container/state_test.go +++ b/pkg/container/state_test.go @@ -6,12 +6,13 @@ import ( "testing" "time" + "github.com/host-uk/core/pkg/io" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestNewState_Good(t *testing.T) { - state := NewState("/tmp/test-state.json") + state := NewState(io.Local, "/tmp/test-state.json") assert.NotNil(t, state) assert.NotNil(t, state.Containers) @@ -23,7 +24,7 @@ func TestLoadState_Good_NewFile(t *testing.T) { tmpDir := t.TempDir() statePath := filepath.Join(tmpDir, "containers.json") - state, err := LoadState(statePath) + state, err := LoadState(io.Local, statePath) require.NoError(t, err) assert.NotNil(t, state) @@ -50,7 +51,7 @@ func TestLoadState_Good_ExistingFile(t *testing.T) { err := os.WriteFile(statePath, []byte(content), 0644) require.NoError(t, err) - state, err := LoadState(statePath) + state, err := LoadState(io.Local, statePath) require.NoError(t, err) assert.Len(t, state.Containers, 1) @@ -69,14 +70,14 @@ func TestLoadState_Bad_InvalidJSON(t *testing.T) { err := os.WriteFile(statePath, []byte("invalid json{"), 0644) require.NoError(t, err) - _, err = LoadState(statePath) + _, err = LoadState(io.Local, statePath) assert.Error(t, err) } func TestState_Add_Good(t *testing.T) { tmpDir := t.TempDir() statePath := filepath.Join(tmpDir, "containers.json") - state := NewState(statePath) + state := NewState(io.Local, statePath) container := &Container{ ID: "abc12345", @@ -103,7 +104,7 @@ func TestState_Add_Good(t *testing.T) { func TestState_Update_Good(t *testing.T) { tmpDir := t.TempDir() statePath := filepath.Join(tmpDir, "containers.json") - state := NewState(statePath) + state := NewState(io.Local, statePath) container := &Container{ ID: "abc12345", @@ -125,7 +126,7 @@ func TestState_Update_Good(t *testing.T) { func TestState_Remove_Good(t *testing.T) { tmpDir := t.TempDir() statePath := filepath.Join(tmpDir, "containers.json") - state := NewState(statePath) + state := NewState(io.Local, statePath) container := &Container{ ID: "abc12345", @@ -140,7 +141,7 @@ func TestState_Remove_Good(t *testing.T) { } func TestState_Get_Bad_NotFound(t *testing.T) { - state := NewState("/tmp/test-state.json") + state := NewState(io.Local, "/tmp/test-state.json") _, ok := state.Get("nonexistent") assert.False(t, ok) @@ -149,7 +150,7 @@ func TestState_Get_Bad_NotFound(t *testing.T) { func TestState_All_Good(t *testing.T) { tmpDir := t.TempDir() statePath := filepath.Join(tmpDir, "containers.json") - state := NewState(statePath) + state := NewState(io.Local, statePath) _ = state.Add(&Container{ID: "aaa11111"}) _ = state.Add(&Container{ID: "bbb22222"}) @@ -162,7 +163,7 @@ func TestState_All_Good(t *testing.T) { func TestState_SaveState_Good_CreatesDirectory(t *testing.T) { tmpDir := t.TempDir() nestedPath := filepath.Join(tmpDir, "nested", "dir", "containers.json") - state := NewState(nestedPath) + state := NewState(io.Local, nestedPath) _ = state.Add(&Container{ID: "abc12345"}) @@ -200,7 +201,7 @@ func TestLogPath_Good(t *testing.T) { func TestEnsureLogsDir_Good(t *testing.T) { // This test creates real directories - skip in CI if needed - err := EnsureLogsDir() + err := EnsureLogsDir(io.Local) assert.NoError(t, err) logsDir, _ := DefaultLogsDir() diff --git a/pkg/container/templates.go b/pkg/container/templates.go index 80ec3005..263337a6 100644 --- a/pkg/container/templates.go +++ b/pkg/container/templates.go @@ -38,17 +38,52 @@ var builtinTemplates = []Template{ }, } +// TemplateManager manages LinuxKit templates using a storage medium. +type TemplateManager struct { + medium io.Medium + workingDir string + homeDir string +} + +// NewTemplateManager creates a new TemplateManager instance. +func NewTemplateManager(m io.Medium) *TemplateManager { + tm := &TemplateManager{medium: m} + + // Default working and home directories from local system + // These can be overridden if needed. + if wd, err := os.Getwd(); err == nil { + tm.workingDir = wd + } + if home, err := os.UserHomeDir(); err == nil { + tm.homeDir = home + } + + return tm +} + +// WithWorkingDir sets the working directory for user template discovery. +func (tm *TemplateManager) WithWorkingDir(wd string) *TemplateManager { + tm.workingDir = wd + return tm +} + +// WithHomeDir sets the home directory for user template discovery. +func (tm *TemplateManager) WithHomeDir(home string) *TemplateManager { + tm.homeDir = home + return tm +} + // ListTemplates returns all available LinuxKit templates. // It combines embedded templates with any templates found in the user's // .core/linuxkit directory. -func ListTemplates() []Template { +func (tm *TemplateManager) ListTemplates() []Template { templates := make([]Template, len(builtinTemplates)) copy(templates, builtinTemplates) // Check for user templates in .core/linuxkit/ - userTemplatesDir := getUserTemplatesDir() + userTemplatesDir := tm.getUserTemplatesDir() if userTemplatesDir != "" { - userTemplates := scanUserTemplates(userTemplatesDir) + userTemplates := tm.scanUserTemplates(userTemplatesDir) templates = append(templates, userTemplates...) } @@ -57,7 +92,7 @@ func ListTemplates() []Template { // GetTemplate returns the content of a template by name. // It first checks embedded templates, then user templates. -func GetTemplate(name string) (string, error) { +func (tm *TemplateManager) GetTemplate(name string) (string, error) { // Check embedded templates first for _, t := range builtinTemplates { if t.Name == name { @@ -70,15 +105,18 @@ func GetTemplate(name string) (string, error) { } // Check user templates - userTemplatesDir := getUserTemplatesDir() + userTemplatesDir := tm.getUserTemplatesDir() if userTemplatesDir != "" { - templatePath := filepath.Join(userTemplatesDir, name+".yml") - if io.Local.IsFile(templatePath) { - content, err := io.Local.Read(templatePath) - if err != nil { - return "", fmt.Errorf("failed to read user template %s: %w", name, err) + // Check both .yml and .yaml extensions + for _, ext := range []string{".yml", ".yaml"} { + templatePath := filepath.Join(userTemplatesDir, name+ext) + if tm.medium.IsFile(templatePath) { + content, err := tm.medium.Read(templatePath) + if err != nil { + return "", fmt.Errorf("failed to read user template %s: %w", name, err) + } + return content, nil } - return content, nil } } @@ -86,11 +124,8 @@ func GetTemplate(name string) (string, error) { } // ApplyTemplate applies variable substitution to a template. -// It supports two syntaxes: -// - ${VAR} - required variable, returns error if not provided -// - ${VAR:-default} - variable with default value -func ApplyTemplate(name string, vars map[string]string) (string, error) { - content, err := GetTemplate(name) +func (tm *TemplateManager) ApplyTemplate(name string, vars map[string]string) (string, error) { + content, err := tm.GetTemplate(name) if err != nil { return "", err } @@ -191,35 +226,31 @@ func ExtractVariables(content string) (required []string, optional map[string]st // getUserTemplatesDir returns the path to user templates directory. // Returns empty string if the directory doesn't exist. -func getUserTemplatesDir() string { +func (tm *TemplateManager) getUserTemplatesDir() string { // Try workspace-relative .core/linuxkit first - cwd, err := os.Getwd() - if err == nil { - wsDir := filepath.Join(cwd, ".core", "linuxkit") - if io.Local.IsDir(wsDir) { + if tm.workingDir != "" { + wsDir := filepath.Join(tm.workingDir, ".core", "linuxkit") + if tm.medium.IsDir(wsDir) { return wsDir } } // Try home directory - home, err := os.UserHomeDir() - if err != nil { - return "" - } - - homeDir := filepath.Join(home, ".core", "linuxkit") - if io.Local.IsDir(homeDir) { - return homeDir + if tm.homeDir != "" { + homeDir := filepath.Join(tm.homeDir, ".core", "linuxkit") + if tm.medium.IsDir(homeDir) { + return homeDir + } } return "" } // scanUserTemplates scans a directory for .yml template files. -func scanUserTemplates(dir string) []Template { +func (tm *TemplateManager) scanUserTemplates(dir string) []Template { var templates []Template - entries, err := io.Local.List(dir) + entries, err := tm.medium.List(dir) if err != nil { return templates } @@ -250,7 +281,7 @@ func scanUserTemplates(dir string) []Template { } // Read file to extract description from comments - description := extractTemplateDescription(filepath.Join(dir, name)) + description := tm.extractTemplateDescription(filepath.Join(dir, name)) if description == "" { description = "User-defined template" } @@ -267,8 +298,8 @@ func scanUserTemplates(dir string) []Template { // extractTemplateDescription reads the first comment block from a YAML file // to use as a description. -func extractTemplateDescription(path string) string { - content, err := io.Local.Read(path) +func (tm *TemplateManager) extractTemplateDescription(path string) string { + content, err := tm.medium.Read(path) if err != nil { return "" } diff --git a/pkg/container/templates_test.go b/pkg/container/templates_test.go index e4a78aa5..c1db5a4e 100644 --- a/pkg/container/templates_test.go +++ b/pkg/container/templates_test.go @@ -6,12 +6,14 @@ import ( "strings" "testing" + "github.com/host-uk/core/pkg/io" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestListTemplates_Good(t *testing.T) { - templates := ListTemplates() + tm := NewTemplateManager(io.Local) + templates := tm.ListTemplates() // Should have at least the builtin templates assert.GreaterOrEqual(t, len(templates), 2) @@ -42,7 +44,8 @@ func TestListTemplates_Good(t *testing.T) { } func TestGetTemplate_Good_CoreDev(t *testing.T) { - content, err := GetTemplate("core-dev") + tm := NewTemplateManager(io.Local) + content, err := tm.GetTemplate("core-dev") require.NoError(t, err) assert.NotEmpty(t, content) @@ -53,7 +56,8 @@ func TestGetTemplate_Good_CoreDev(t *testing.T) { } func TestGetTemplate_Good_ServerPhp(t *testing.T) { - content, err := GetTemplate("server-php") + tm := NewTemplateManager(io.Local) + content, err := tm.GetTemplate("server-php") require.NoError(t, err) assert.NotEmpty(t, content) @@ -64,7 +68,8 @@ func TestGetTemplate_Good_ServerPhp(t *testing.T) { } func TestGetTemplate_Bad_NotFound(t *testing.T) { - _, err := GetTemplate("nonexistent-template") + tm := NewTemplateManager(io.Local) + _, err := tm.GetTemplate("nonexistent-template") assert.Error(t, err) assert.Contains(t, err.Error(), "template not found") @@ -162,11 +167,12 @@ func TestApplyVariables_Bad_MultipleMissing(t *testing.T) { } func TestApplyTemplate_Good(t *testing.T) { + tm := NewTemplateManager(io.Local) vars := map[string]string{ "SSH_KEY": "ssh-rsa AAAA... user@host", } - result, err := ApplyTemplate("core-dev", vars) + result, err := tm.ApplyTemplate("core-dev", vars) require.NoError(t, err) assert.NotEmpty(t, result) @@ -176,21 +182,23 @@ func TestApplyTemplate_Good(t *testing.T) { } func TestApplyTemplate_Bad_TemplateNotFound(t *testing.T) { + tm := NewTemplateManager(io.Local) vars := map[string]string{ "SSH_KEY": "test", } - _, err := ApplyTemplate("nonexistent", vars) + _, err := tm.ApplyTemplate("nonexistent", vars) assert.Error(t, err) assert.Contains(t, err.Error(), "template not found") } func TestApplyTemplate_Bad_MissingVariable(t *testing.T) { + tm := NewTemplateManager(io.Local) // server-php requires SSH_KEY vars := map[string]string{} // Missing required SSH_KEY - _, err := ApplyTemplate("server-php", vars) + _, err := tm.ApplyTemplate("server-php", vars) assert.Error(t, err) assert.Contains(t, err.Error(), "missing required variables") @@ -239,6 +247,7 @@ func TestExtractVariables_Good_OnlyDefaults(t *testing.T) { } func TestScanUserTemplates_Good(t *testing.T) { + tm := NewTemplateManager(io.Local) // Create a temporary directory with template files tmpDir := t.TempDir() @@ -255,7 +264,7 @@ kernel: err = os.WriteFile(filepath.Join(tmpDir, "readme.txt"), []byte("Not a template"), 0644) require.NoError(t, err) - templates := scanUserTemplates(tmpDir) + templates := tm.scanUserTemplates(tmpDir) assert.Len(t, templates, 1) assert.Equal(t, "custom", templates[0].Name) @@ -263,6 +272,7 @@ kernel: } func TestScanUserTemplates_Good_MultipleTemplates(t *testing.T) { + tm := NewTemplateManager(io.Local) tmpDir := t.TempDir() // Create multiple template files @@ -271,7 +281,7 @@ func TestScanUserTemplates_Good_MultipleTemplates(t *testing.T) { err = os.WriteFile(filepath.Join(tmpDir, "db.yaml"), []byte("# Database Server\nkernel:"), 0644) require.NoError(t, err) - templates := scanUserTemplates(tmpDir) + templates := tm.scanUserTemplates(tmpDir) assert.Len(t, templates, 2) @@ -285,20 +295,23 @@ func TestScanUserTemplates_Good_MultipleTemplates(t *testing.T) { } func TestScanUserTemplates_Good_EmptyDirectory(t *testing.T) { + tm := NewTemplateManager(io.Local) tmpDir := t.TempDir() - templates := scanUserTemplates(tmpDir) + templates := tm.scanUserTemplates(tmpDir) assert.Empty(t, templates) } func TestScanUserTemplates_Bad_NonexistentDirectory(t *testing.T) { - templates := scanUserTemplates("/nonexistent/path/to/templates") + tm := NewTemplateManager(io.Local) + templates := tm.scanUserTemplates("/nonexistent/path/to/templates") assert.Empty(t, templates) } func TestExtractTemplateDescription_Good(t *testing.T) { + tm := NewTemplateManager(io.Local) tmpDir := t.TempDir() path := filepath.Join(tmpDir, "test.yml") @@ -310,12 +323,13 @@ kernel: err := os.WriteFile(path, []byte(content), 0644) require.NoError(t, err) - desc := extractTemplateDescription(path) + desc := tm.extractTemplateDescription(path) assert.Equal(t, "My Template Description", desc) } func TestExtractTemplateDescription_Good_NoComments(t *testing.T) { + tm := NewTemplateManager(io.Local) tmpDir := t.TempDir() path := filepath.Join(tmpDir, "test.yml") @@ -325,13 +339,14 @@ func TestExtractTemplateDescription_Good_NoComments(t *testing.T) { err := os.WriteFile(path, []byte(content), 0644) require.NoError(t, err) - desc := extractTemplateDescription(path) + desc := tm.extractTemplateDescription(path) assert.Empty(t, desc) } func TestExtractTemplateDescription_Bad_FileNotFound(t *testing.T) { - desc := extractTemplateDescription("/nonexistent/file.yml") + tm := NewTemplateManager(io.Local) + desc := tm.extractTemplateDescription("/nonexistent/file.yml") assert.Empty(t, desc) } @@ -399,14 +414,8 @@ kernel: err = os.WriteFile(filepath.Join(coreDir, "user-custom.yml"), []byte(templateContent), 0644) require.NoError(t, err) - // Change to the temp directory - oldWd, err := os.Getwd() - require.NoError(t, err) - err = os.Chdir(tmpDir) - require.NoError(t, err) - defer func() { _ = os.Chdir(oldWd) }() - - templates := ListTemplates() + tm := NewTemplateManager(io.Local).WithWorkingDir(tmpDir) + templates := tm.ListTemplates() // Should have at least the builtin templates plus the user template assert.GreaterOrEqual(t, len(templates), 3) @@ -440,21 +449,39 @@ services: err = os.WriteFile(filepath.Join(coreDir, "my-user-template.yml"), []byte(templateContent), 0644) require.NoError(t, err) - // Change to the temp directory - oldWd, err := os.Getwd() - require.NoError(t, err) - err = os.Chdir(tmpDir) - require.NoError(t, err) - defer func() { _ = os.Chdir(oldWd) }() - - content, err := GetTemplate("my-user-template") + tm := NewTemplateManager(io.Local).WithWorkingDir(tmpDir) + content, err := tm.GetTemplate("my-user-template") require.NoError(t, err) assert.Contains(t, content, "kernel:") assert.Contains(t, content, "My user template") } +func TestGetTemplate_Good_UserTemplate_YamlExtension(t *testing.T) { + // Create a workspace directory with user templates + tmpDir := t.TempDir() + coreDir := filepath.Join(tmpDir, ".core", "linuxkit") + err := os.MkdirAll(coreDir, 0755) + require.NoError(t, err) + + // Create a user template with .yaml extension + templateContent := `# My yaml template +kernel: + image: linuxkit/kernel:6.6 +` + err = os.WriteFile(filepath.Join(coreDir, "my-yaml-template.yaml"), []byte(templateContent), 0644) + require.NoError(t, err) + + tm := NewTemplateManager(io.Local).WithWorkingDir(tmpDir) + content, err := tm.GetTemplate("my-yaml-template") + + require.NoError(t, err) + assert.Contains(t, content, "kernel:") + assert.Contains(t, content, "My yaml template") +} + func TestScanUserTemplates_Good_SkipsBuiltinNames(t *testing.T) { + tm := NewTemplateManager(io.Local) tmpDir := t.TempDir() // Create a template with a builtin name (should be skipped) @@ -465,7 +492,7 @@ func TestScanUserTemplates_Good_SkipsBuiltinNames(t *testing.T) { err = os.WriteFile(filepath.Join(tmpDir, "unique.yml"), []byte("# Unique\nkernel:"), 0644) require.NoError(t, err) - templates := scanUserTemplates(tmpDir) + templates := tm.scanUserTemplates(tmpDir) // Should only have the unique template, not the builtin name assert.Len(t, templates, 1) @@ -473,6 +500,7 @@ func TestScanUserTemplates_Good_SkipsBuiltinNames(t *testing.T) { } func TestScanUserTemplates_Good_SkipsDirectories(t *testing.T) { + tm := NewTemplateManager(io.Local) tmpDir := t.TempDir() // Create a subdirectory (should be skipped) @@ -483,13 +511,14 @@ func TestScanUserTemplates_Good_SkipsDirectories(t *testing.T) { err = os.WriteFile(filepath.Join(tmpDir, "valid.yml"), []byte("# Valid\nkernel:"), 0644) require.NoError(t, err) - templates := scanUserTemplates(tmpDir) + templates := tm.scanUserTemplates(tmpDir) assert.Len(t, templates, 1) assert.Equal(t, "valid", templates[0].Name) } func TestScanUserTemplates_Good_YamlExtension(t *testing.T) { + tm := NewTemplateManager(io.Local) tmpDir := t.TempDir() // Create templates with both extensions @@ -498,7 +527,7 @@ func TestScanUserTemplates_Good_YamlExtension(t *testing.T) { err = os.WriteFile(filepath.Join(tmpDir, "template2.yaml"), []byte("# Template 2\nkernel:"), 0644) require.NoError(t, err) - templates := scanUserTemplates(tmpDir) + templates := tm.scanUserTemplates(tmpDir) assert.Len(t, templates, 2) @@ -511,6 +540,7 @@ func TestScanUserTemplates_Good_YamlExtension(t *testing.T) { } func TestExtractTemplateDescription_Good_EmptyComment(t *testing.T) { + tm := NewTemplateManager(io.Local) tmpDir := t.TempDir() path := filepath.Join(tmpDir, "test.yml") @@ -523,12 +553,13 @@ kernel: err := os.WriteFile(path, []byte(content), 0644) require.NoError(t, err) - desc := extractTemplateDescription(path) + desc := tm.extractTemplateDescription(path) assert.Equal(t, "Actual description here", desc) } func TestExtractTemplateDescription_Good_MultipleEmptyComments(t *testing.T) { + tm := NewTemplateManager(io.Local) tmpDir := t.TempDir() path := filepath.Join(tmpDir, "test.yml") @@ -543,30 +574,20 @@ kernel: err := os.WriteFile(path, []byte(content), 0644) require.NoError(t, err) - desc := extractTemplateDescription(path) + desc := tm.extractTemplateDescription(path) assert.Equal(t, "Real description", desc) } func TestGetUserTemplatesDir_Good_NoDirectory(t *testing.T) { - // Save current working directory - oldWd, err := os.Getwd() - require.NoError(t, err) + tm := NewTemplateManager(io.Local).WithWorkingDir("/tmp/nonexistent-wd").WithHomeDir("/tmp/nonexistent-home") + dir := tm.getUserTemplatesDir() - // Create a temp directory without .core/linuxkit - tmpDir := t.TempDir() - err = os.Chdir(tmpDir) - require.NoError(t, err) - defer func() { _ = os.Chdir(oldWd) }() - - dir := getUserTemplatesDir() - - // Should return empty string since no templates dir exists - // (unless home dir has one) - assert.True(t, dir == "" || strings.Contains(dir, "linuxkit")) + assert.Empty(t, dir) } func TestScanUserTemplates_Good_DefaultDescription(t *testing.T) { + tm := NewTemplateManager(io.Local) tmpDir := t.TempDir() // Create a template without comments @@ -576,7 +597,7 @@ func TestScanUserTemplates_Good_DefaultDescription(t *testing.T) { err := os.WriteFile(filepath.Join(tmpDir, "nocomment.yml"), []byte(content), 0644) require.NoError(t, err) - templates := scanUserTemplates(tmpDir) + templates := tm.scanUserTemplates(tmpDir) assert.Len(t, templates, 1) assert.Equal(t, "User-defined template", templates[0].Description) diff --git a/pkg/devops/config.go b/pkg/devops/config.go index ab91790c..ee6a5178 100644 --- a/pkg/devops/config.go +++ b/pkg/devops/config.go @@ -62,15 +62,15 @@ func ConfigPath() (string, error) { return filepath.Join(home, ".core", "config.yaml"), nil } -// LoadConfig loads configuration from ~/.core/config.yaml. +// LoadConfig loads configuration from ~/.core/config.yaml using the provided medium. // Returns default config if file doesn't exist. -func LoadConfig() (*Config, error) { +func LoadConfig(m io.Medium) (*Config, error) { configPath, err := ConfigPath() if err != nil { return DefaultConfig(), nil } - content, err := io.Local.Read(configPath) + content, err := m.Read(configPath) if err != nil { if os.IsNotExist(err) { return DefaultConfig(), nil diff --git a/pkg/devops/config_test.go b/pkg/devops/config_test.go index cdd4ec7b..5ca5fa2b 100644 --- a/pkg/devops/config_test.go +++ b/pkg/devops/config_test.go @@ -5,6 +5,7 @@ import ( "path/filepath" "testing" + "github.com/host-uk/core/pkg/io" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -30,7 +31,7 @@ func TestLoadConfig_Good(t *testing.T) { t.Setenv("HOME", tempHome) defer func() { _ = os.Setenv("HOME", origHome) }() - cfg, err := LoadConfig() + cfg, err := LoadConfig(io.Local) assert.NoError(t, err) assert.Equal(t, DefaultConfig(), cfg) }) @@ -53,7 +54,7 @@ images: err = os.WriteFile(filepath.Join(coreDir, "config.yaml"), []byte(configData), 0644) require.NoError(t, err) - cfg, err := LoadConfig() + cfg, err := LoadConfig(io.Local) assert.NoError(t, err) assert.Equal(t, 2, cfg.Version) assert.Equal(t, "cdn", cfg.Images.Source) @@ -73,7 +74,7 @@ func TestLoadConfig_Bad(t *testing.T) { err = os.WriteFile(filepath.Join(coreDir, "config.yaml"), []byte("invalid: yaml: :"), 0644) require.NoError(t, err) - _, err = LoadConfig() + _, err = LoadConfig(io.Local) assert.Error(t, err) }) } @@ -127,7 +128,7 @@ images: err = os.WriteFile(filepath.Join(coreDir, "config.yaml"), []byte(configData), 0644) require.NoError(t, err) - cfg, err := LoadConfig() + cfg, err := LoadConfig(io.Local) assert.NoError(t, err) assert.Equal(t, 1, cfg.Version) assert.Equal(t, "github", cfg.Images.Source) @@ -197,7 +198,7 @@ images: err = os.WriteFile(filepath.Join(coreDir, "config.yaml"), []byte(tt.config), 0644) require.NoError(t, err) - cfg, err := LoadConfig() + cfg, err := LoadConfig(io.Local) assert.NoError(t, err) tt.check(t, cfg) }) @@ -246,7 +247,7 @@ func TestLoadConfig_Bad_UnreadableFile(t *testing.T) { err = os.WriteFile(configPath, []byte("version: 1"), 0000) require.NoError(t, err) - _, err = LoadConfig() + _, err = LoadConfig(io.Local) assert.Error(t, err) // Restore permissions so cleanup works diff --git a/pkg/devops/devops.go b/pkg/devops/devops.go index 9b0491c4..2cad57c2 100644 --- a/pkg/devops/devops.go +++ b/pkg/devops/devops.go @@ -15,29 +15,31 @@ import ( // DevOps manages the portable development environment. type DevOps struct { + medium io.Medium config *Config images *ImageManager container *container.LinuxKitManager } -// New creates a new DevOps instance. -func New() (*DevOps, error) { - cfg, err := LoadConfig() +// New creates a new DevOps instance using the provided medium. +func New(m io.Medium) (*DevOps, error) { + cfg, err := LoadConfig(m) if err != nil { return nil, fmt.Errorf("devops.New: failed to load config: %w", err) } - images, err := NewImageManager(cfg) + images, err := NewImageManager(m, cfg) if err != nil { return nil, fmt.Errorf("devops.New: failed to create image manager: %w", err) } - mgr, err := container.NewLinuxKitManager() + mgr, err := container.NewLinuxKitManager(io.Local) if err != nil { return nil, fmt.Errorf("devops.New: failed to create container manager: %w", err) } return &DevOps{ + medium: m, config: cfg, images: images, container: mgr, @@ -76,7 +78,7 @@ func (d *DevOps) IsInstalled() bool { if err != nil { return false } - return io.Local.IsFile(path) + return d.medium.IsFile(path) } // Install downloads and installs the dev image. diff --git a/pkg/devops/devops_test.go b/pkg/devops/devops_test.go index 4b75b8d0..2aef52fe 100644 --- a/pkg/devops/devops_test.go +++ b/pkg/devops/devops_test.go @@ -10,6 +10,7 @@ import ( "time" "github.com/host-uk/core/pkg/container" + "github.com/host-uk/core/pkg/io" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -69,7 +70,7 @@ func TestIsInstalled_Bad(t *testing.T) { t.Setenv("CORE_IMAGES_DIR", tempDir) // Create devops instance manually to avoid loading real config/images - d := &DevOps{} + d := &DevOps{medium: io.Local} assert.False(t, d.IsInstalled()) }) } @@ -84,7 +85,7 @@ func TestIsInstalled_Good(t *testing.T) { err := os.WriteFile(imagePath, []byte("fake image data"), 0644) require.NoError(t, err) - d := &DevOps{} + d := &DevOps{medium: io.Local} assert.True(t, d.IsInstalled()) }) } @@ -102,16 +103,16 @@ func TestDevOps_Status_Good(t *testing.T) { t.Setenv("CORE_IMAGES_DIR", tempDir) cfg := DefaultConfig() - mgr, err := NewImageManager(cfg) + mgr, err := NewImageManager(io.Local, cfg) require.NoError(t, err) // Setup mock container manager statePath := filepath.Join(tempDir, "containers.json") - state := container.NewState(statePath) + state := container.NewState(io.Local, statePath) h := &mockHypervisor{} - cm := container.NewLinuxKitManagerWithHypervisor(state, h) + cm := container.NewLinuxKitManagerWithHypervisor(io.Local, state, h) - d := &DevOps{ + d := &DevOps{medium: io.Local, images: mgr, container: cm, } @@ -143,15 +144,15 @@ func TestDevOps_Status_Good_NotInstalled(t *testing.T) { t.Setenv("CORE_IMAGES_DIR", tempDir) cfg := DefaultConfig() - mgr, err := NewImageManager(cfg) + mgr, err := NewImageManager(io.Local, cfg) require.NoError(t, err) statePath := filepath.Join(tempDir, "containers.json") - state := container.NewState(statePath) + state := container.NewState(io.Local, statePath) h := &mockHypervisor{} - cm := container.NewLinuxKitManagerWithHypervisor(state, h) + cm := container.NewLinuxKitManagerWithHypervisor(io.Local, state, h) - d := &DevOps{ + d := &DevOps{medium: io.Local, images: mgr, container: cm, } @@ -174,15 +175,15 @@ func TestDevOps_Status_Good_NoContainer(t *testing.T) { require.NoError(t, err) cfg := DefaultConfig() - mgr, err := NewImageManager(cfg) + mgr, err := NewImageManager(io.Local, cfg) require.NoError(t, err) statePath := filepath.Join(tempDir, "containers.json") - state := container.NewState(statePath) + state := container.NewState(io.Local, statePath) h := &mockHypervisor{} - cm := container.NewLinuxKitManagerWithHypervisor(state, h) + cm := container.NewLinuxKitManagerWithHypervisor(io.Local, state, h) - d := &DevOps{ + d := &DevOps{medium: io.Local, images: mgr, container: cm, } @@ -200,15 +201,15 @@ func TestDevOps_IsRunning_Good(t *testing.T) { t.Setenv("CORE_IMAGES_DIR", tempDir) cfg := DefaultConfig() - mgr, err := NewImageManager(cfg) + mgr, err := NewImageManager(io.Local, cfg) require.NoError(t, err) statePath := filepath.Join(tempDir, "containers.json") - state := container.NewState(statePath) + state := container.NewState(io.Local, statePath) h := &mockHypervisor{} - cm := container.NewLinuxKitManagerWithHypervisor(state, h) + cm := container.NewLinuxKitManagerWithHypervisor(io.Local, state, h) - d := &DevOps{ + d := &DevOps{medium: io.Local, images: mgr, container: cm, } @@ -233,15 +234,15 @@ func TestDevOps_IsRunning_Bad_NotRunning(t *testing.T) { t.Setenv("CORE_IMAGES_DIR", tempDir) cfg := DefaultConfig() - mgr, err := NewImageManager(cfg) + mgr, err := NewImageManager(io.Local, cfg) require.NoError(t, err) statePath := filepath.Join(tempDir, "containers.json") - state := container.NewState(statePath) + state := container.NewState(io.Local, statePath) h := &mockHypervisor{} - cm := container.NewLinuxKitManagerWithHypervisor(state, h) + cm := container.NewLinuxKitManagerWithHypervisor(io.Local, state, h) - d := &DevOps{ + d := &DevOps{medium: io.Local, images: mgr, container: cm, } @@ -256,15 +257,15 @@ func TestDevOps_IsRunning_Bad_ContainerStopped(t *testing.T) { t.Setenv("CORE_IMAGES_DIR", tempDir) cfg := DefaultConfig() - mgr, err := NewImageManager(cfg) + mgr, err := NewImageManager(io.Local, cfg) require.NoError(t, err) statePath := filepath.Join(tempDir, "containers.json") - state := container.NewState(statePath) + state := container.NewState(io.Local, statePath) h := &mockHypervisor{} - cm := container.NewLinuxKitManagerWithHypervisor(state, h) + cm := container.NewLinuxKitManagerWithHypervisor(io.Local, state, h) - d := &DevOps{ + d := &DevOps{medium: io.Local, images: mgr, container: cm, } @@ -289,15 +290,15 @@ func TestDevOps_findContainer_Good(t *testing.T) { t.Setenv("CORE_IMAGES_DIR", tempDir) cfg := DefaultConfig() - mgr, err := NewImageManager(cfg) + mgr, err := NewImageManager(io.Local, cfg) require.NoError(t, err) statePath := filepath.Join(tempDir, "containers.json") - state := container.NewState(statePath) + state := container.NewState(io.Local, statePath) h := &mockHypervisor{} - cm := container.NewLinuxKitManagerWithHypervisor(state, h) + cm := container.NewLinuxKitManagerWithHypervisor(io.Local, state, h) - d := &DevOps{ + d := &DevOps{medium: io.Local, images: mgr, container: cm, } @@ -324,15 +325,15 @@ func TestDevOps_findContainer_Bad_NotFound(t *testing.T) { t.Setenv("CORE_IMAGES_DIR", tempDir) cfg := DefaultConfig() - mgr, err := NewImageManager(cfg) + mgr, err := NewImageManager(io.Local, cfg) require.NoError(t, err) statePath := filepath.Join(tempDir, "containers.json") - state := container.NewState(statePath) + state := container.NewState(io.Local, statePath) h := &mockHypervisor{} - cm := container.NewLinuxKitManagerWithHypervisor(state, h) + cm := container.NewLinuxKitManagerWithHypervisor(io.Local, state, h) - d := &DevOps{ + d := &DevOps{medium: io.Local, images: mgr, container: cm, } @@ -347,15 +348,15 @@ func TestDevOps_Stop_Bad_NotFound(t *testing.T) { t.Setenv("CORE_IMAGES_DIR", tempDir) cfg := DefaultConfig() - mgr, err := NewImageManager(cfg) + mgr, err := NewImageManager(io.Local, cfg) require.NoError(t, err) statePath := filepath.Join(tempDir, "containers.json") - state := container.NewState(statePath) + state := container.NewState(io.Local, statePath) h := &mockHypervisor{} - cm := container.NewLinuxKitManagerWithHypervisor(state, h) + cm := container.NewLinuxKitManagerWithHypervisor(io.Local, state, h) - d := &DevOps{ + d := &DevOps{medium: io.Local, images: mgr, container: cm, } @@ -404,15 +405,15 @@ func TestDevOps_Boot_Bad_NotInstalled(t *testing.T) { t.Setenv("CORE_IMAGES_DIR", tempDir) cfg := DefaultConfig() - mgr, err := NewImageManager(cfg) + mgr, err := NewImageManager(io.Local, cfg) require.NoError(t, err) statePath := filepath.Join(tempDir, "containers.json") - state := container.NewState(statePath) + state := container.NewState(io.Local, statePath) h := &mockHypervisor{} - cm := container.NewLinuxKitManagerWithHypervisor(state, h) + cm := container.NewLinuxKitManagerWithHypervisor(io.Local, state, h) - d := &DevOps{ + d := &DevOps{medium: io.Local, images: mgr, container: cm, } @@ -432,15 +433,15 @@ func TestDevOps_Boot_Bad_AlreadyRunning(t *testing.T) { require.NoError(t, err) cfg := DefaultConfig() - mgr, err := NewImageManager(cfg) + mgr, err := NewImageManager(io.Local, cfg) require.NoError(t, err) statePath := filepath.Join(tempDir, "containers.json") - state := container.NewState(statePath) + state := container.NewState(io.Local, statePath) h := &mockHypervisor{} - cm := container.NewLinuxKitManagerWithHypervisor(state, h) + cm := container.NewLinuxKitManagerWithHypervisor(io.Local, state, h) - d := &DevOps{ + d := &DevOps{medium: io.Local, images: mgr, container: cm, } @@ -471,7 +472,7 @@ func TestDevOps_Status_Good_WithImageVersion(t *testing.T) { require.NoError(t, err) cfg := DefaultConfig() - mgr, err := NewImageManager(cfg) + mgr, err := NewImageManager(io.Local, cfg) require.NoError(t, err) // Manually set manifest with version info @@ -481,11 +482,11 @@ func TestDevOps_Status_Good_WithImageVersion(t *testing.T) { } statePath := filepath.Join(tempDir, "containers.json") - state := container.NewState(statePath) + state := container.NewState(io.Local, statePath) h := &mockHypervisor{} - cm := container.NewLinuxKitManagerWithHypervisor(state, h) + cm := container.NewLinuxKitManagerWithHypervisor(io.Local, state, h) - d := &DevOps{ + d := &DevOps{medium: io.Local, config: cfg, images: mgr, container: cm, @@ -502,15 +503,15 @@ func TestDevOps_findContainer_Good_MultipleContainers(t *testing.T) { t.Setenv("CORE_IMAGES_DIR", tempDir) cfg := DefaultConfig() - mgr, err := NewImageManager(cfg) + mgr, err := NewImageManager(io.Local, cfg) require.NoError(t, err) statePath := filepath.Join(tempDir, "containers.json") - state := container.NewState(statePath) + state := container.NewState(io.Local, statePath) h := &mockHypervisor{} - cm := container.NewLinuxKitManagerWithHypervisor(state, h) + cm := container.NewLinuxKitManagerWithHypervisor(io.Local, state, h) - d := &DevOps{ + d := &DevOps{medium: io.Local, images: mgr, container: cm, } @@ -547,15 +548,15 @@ func TestDevOps_Status_Good_ContainerWithUptime(t *testing.T) { t.Setenv("CORE_IMAGES_DIR", tempDir) cfg := DefaultConfig() - mgr, err := NewImageManager(cfg) + mgr, err := NewImageManager(io.Local, cfg) require.NoError(t, err) statePath := filepath.Join(tempDir, "containers.json") - state := container.NewState(statePath) + state := container.NewState(io.Local, statePath) h := &mockHypervisor{} - cm := container.NewLinuxKitManagerWithHypervisor(state, h) + cm := container.NewLinuxKitManagerWithHypervisor(io.Local, state, h) - d := &DevOps{ + d := &DevOps{medium: io.Local, images: mgr, container: cm, } @@ -584,15 +585,15 @@ func TestDevOps_IsRunning_Bad_DifferentContainerName(t *testing.T) { t.Setenv("CORE_IMAGES_DIR", tempDir) cfg := DefaultConfig() - mgr, err := NewImageManager(cfg) + mgr, err := NewImageManager(io.Local, cfg) require.NoError(t, err) statePath := filepath.Join(tempDir, "containers.json") - state := container.NewState(statePath) + state := container.NewState(io.Local, statePath) h := &mockHypervisor{} - cm := container.NewLinuxKitManagerWithHypervisor(state, h) + cm := container.NewLinuxKitManagerWithHypervisor(io.Local, state, h) - d := &DevOps{ + d := &DevOps{medium: io.Local, images: mgr, container: cm, } @@ -626,15 +627,15 @@ func TestDevOps_Boot_Good_FreshFlag(t *testing.T) { require.NoError(t, err) cfg := DefaultConfig() - mgr, err := NewImageManager(cfg) + mgr, err := NewImageManager(io.Local, cfg) require.NoError(t, err) statePath := filepath.Join(tempDir, "containers.json") - state := container.NewState(statePath) + state := container.NewState(io.Local, statePath) h := &mockHypervisor{} - cm := container.NewLinuxKitManagerWithHypervisor(state, h) + cm := container.NewLinuxKitManagerWithHypervisor(io.Local, state, h) - d := &DevOps{ + d := &DevOps{medium: io.Local, images: mgr, container: cm, } @@ -668,15 +669,15 @@ func TestDevOps_Stop_Bad_ContainerNotRunning(t *testing.T) { t.Setenv("CORE_IMAGES_DIR", tempDir) cfg := DefaultConfig() - mgr, err := NewImageManager(cfg) + mgr, err := NewImageManager(io.Local, cfg) require.NoError(t, err) statePath := filepath.Join(tempDir, "containers.json") - state := container.NewState(statePath) + state := container.NewState(io.Local, statePath) h := &mockHypervisor{} - cm := container.NewLinuxKitManagerWithHypervisor(state, h) + cm := container.NewLinuxKitManagerWithHypervisor(io.Local, state, h) - d := &DevOps{ + d := &DevOps{medium: io.Local, images: mgr, container: cm, } @@ -710,15 +711,15 @@ func TestDevOps_Boot_Good_FreshWithNoExisting(t *testing.T) { require.NoError(t, err) cfg := DefaultConfig() - mgr, err := NewImageManager(cfg) + mgr, err := NewImageManager(io.Local, cfg) require.NoError(t, err) statePath := filepath.Join(tempDir, "containers.json") - state := container.NewState(statePath) + state := container.NewState(io.Local, statePath) h := &mockHypervisor{} - cm := container.NewLinuxKitManagerWithHypervisor(state, h) + cm := container.NewLinuxKitManagerWithHypervisor(io.Local, state, h) - d := &DevOps{ + d := &DevOps{medium: io.Local, images: mgr, container: cm, } @@ -750,10 +751,10 @@ func TestDevOps_Install_Delegates(t *testing.T) { t.Setenv("CORE_IMAGES_DIR", tempDir) cfg := DefaultConfig() - mgr, err := NewImageManager(cfg) + mgr, err := NewImageManager(io.Local, cfg) require.NoError(t, err) - d := &DevOps{ + d := &DevOps{medium: io.Local, images: mgr, } @@ -768,10 +769,10 @@ func TestDevOps_CheckUpdate_Delegates(t *testing.T) { t.Setenv("CORE_IMAGES_DIR", tempDir) cfg := DefaultConfig() - mgr, err := NewImageManager(cfg) + mgr, err := NewImageManager(io.Local, cfg) require.NoError(t, err) - d := &DevOps{ + d := &DevOps{medium: io.Local, images: mgr, } @@ -792,15 +793,15 @@ func TestDevOps_Boot_Good_Success(t *testing.T) { require.NoError(t, err) cfg := DefaultConfig() - mgr, err := NewImageManager(cfg) + mgr, err := NewImageManager(io.Local, cfg) require.NoError(t, err) statePath := filepath.Join(tempDir, "containers.json") - state := container.NewState(statePath) + state := container.NewState(io.Local, statePath) h := &mockHypervisor{} - cm := container.NewLinuxKitManagerWithHypervisor(state, h) + cm := container.NewLinuxKitManagerWithHypervisor(io.Local, state, h) - d := &DevOps{ + d := &DevOps{medium: io.Local, images: mgr, container: cm, } @@ -816,10 +817,10 @@ func TestDevOps_Config(t *testing.T) { t.Setenv("CORE_IMAGES_DIR", tempDir) cfg := DefaultConfig() - mgr, err := NewImageManager(cfg) + mgr, err := NewImageManager(io.Local, cfg) require.NoError(t, err) - d := &DevOps{ + d := &DevOps{medium: io.Local, config: cfg, images: mgr, } diff --git a/pkg/devops/images.go b/pkg/devops/images.go index e6a93edc..7f2b5745 100644 --- a/pkg/devops/images.go +++ b/pkg/devops/images.go @@ -14,6 +14,7 @@ import ( // ImageManager handles image downloads and updates. type ImageManager struct { + medium io.Medium config *Config manifest *Manifest sources []sources.ImageSource @@ -21,6 +22,7 @@ type ImageManager struct { // Manifest tracks installed images. type Manifest struct { + medium io.Medium Images map[string]ImageInfo `json:"images"` path string } @@ -34,20 +36,20 @@ type ImageInfo struct { } // NewImageManager creates a new image manager. -func NewImageManager(cfg *Config) (*ImageManager, error) { +func NewImageManager(m io.Medium, cfg *Config) (*ImageManager, error) { imagesDir, err := ImagesDir() if err != nil { return nil, err } // Ensure images directory exists - if err := io.Local.EnsureDir(imagesDir); err != nil { + if err := m.EnsureDir(imagesDir); err != nil { return nil, err } // Load or create manifest manifestPath := filepath.Join(imagesDir, "manifest.json") - manifest, err := loadManifest(manifestPath) + manifest, err := loadManifest(m, manifestPath) if err != nil { return nil, err } @@ -75,6 +77,7 @@ func NewImageManager(cfg *Config) (*ImageManager, error) { } return &ImageManager{ + medium: m, config: cfg, manifest: manifest, sources: srcs, @@ -87,7 +90,7 @@ func (m *ImageManager) IsInstalled() bool { if err != nil { return false } - return io.Local.IsFile(path) + return m.medium.IsFile(path) } // Install downloads and installs the dev image. @@ -118,7 +121,7 @@ func (m *ImageManager) Install(ctx context.Context, progress func(downloaded, to fmt.Printf("Downloading %s from %s...\n", ImageName(), src.Name()) // Download - if err := src.Download(ctx, imagesDir, progress); err != nil { + if err := src.Download(ctx, m.medium, imagesDir, progress); err != nil { return err } @@ -161,26 +164,28 @@ func (m *ImageManager) CheckUpdate(ctx context.Context) (current, latest string, return current, latest, hasUpdate, nil } -func loadManifest(path string) (*Manifest, error) { - m := &Manifest{ +func loadManifest(m io.Medium, path string) (*Manifest, error) { + manifest := &Manifest{ + medium: m, Images: make(map[string]ImageInfo), path: path, } - content, err := io.Local.Read(path) + content, err := m.Read(path) if err != nil { if os.IsNotExist(err) { - return m, nil + return manifest, nil } return nil, err } - if err := json.Unmarshal([]byte(content), m); err != nil { + if err := json.Unmarshal([]byte(content), manifest); err != nil { return nil, err } - m.path = path + manifest.medium = m + manifest.path = path - return m, nil + return manifest, nil } // Save writes the manifest to disk. @@ -189,5 +194,5 @@ func (m *Manifest) Save() error { if err != nil { return err } - return io.Local.Write(m.path, string(data)) + return m.medium.Write(m.path, string(data)) } diff --git a/pkg/devops/images_test.go b/pkg/devops/images_test.go index 8252efb5..72eeb3df 100644 --- a/pkg/devops/images_test.go +++ b/pkg/devops/images_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/host-uk/core/pkg/devops/sources" + "github.com/host-uk/core/pkg/io" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -17,7 +18,7 @@ func TestImageManager_Good_IsInstalled(t *testing.T) { t.Setenv("CORE_IMAGES_DIR", tmpDir) cfg := DefaultConfig() - mgr, err := NewImageManager(cfg) + mgr, err := NewImageManager(io.Local, cfg) require.NoError(t, err) // Not installed yet @@ -40,7 +41,7 @@ func TestNewImageManager_Good(t *testing.T) { cfg := DefaultConfig() cfg.Images.Source = "cdn" - mgr, err := NewImageManager(cfg) + mgr, err := NewImageManager(io.Local, cfg) assert.NoError(t, err) assert.NotNil(t, mgr) assert.Len(t, mgr.sources, 1) @@ -54,7 +55,7 @@ func TestNewImageManager_Good(t *testing.T) { cfg := DefaultConfig() cfg.Images.Source = "github" - mgr, err := NewImageManager(cfg) + mgr, err := NewImageManager(io.Local, cfg) assert.NoError(t, err) assert.NotNil(t, mgr) assert.Len(t, mgr.sources, 1) @@ -67,6 +68,7 @@ func TestManifest_Save(t *testing.T) { path := filepath.Join(tmpDir, "manifest.json") m := &Manifest{ + medium: io.Local, Images: make(map[string]ImageInfo), path: path, } @@ -84,7 +86,7 @@ func TestManifest_Save(t *testing.T) { assert.NoError(t, err) // Reload - m2, err := loadManifest(path) + m2, err := loadManifest(io.Local, path) assert.NoError(t, err) assert.Equal(t, "1.0.0", m2.Images["test.img"].Version) } @@ -96,7 +98,7 @@ func TestLoadManifest_Bad(t *testing.T) { err := os.WriteFile(path, []byte("invalid json"), 0644) require.NoError(t, err) - _, err = loadManifest(path) + _, err = loadManifest(io.Local, path) assert.Error(t, err) }) } @@ -107,7 +109,7 @@ func TestCheckUpdate_Bad(t *testing.T) { t.Setenv("CORE_IMAGES_DIR", tmpDir) cfg := DefaultConfig() - mgr, err := NewImageManager(cfg) + mgr, err := NewImageManager(io.Local, cfg) require.NoError(t, err) _, _, _, err = mgr.CheckUpdate(context.Background()) @@ -123,7 +125,7 @@ func TestNewImageManager_Good_AutoSource(t *testing.T) { cfg := DefaultConfig() cfg.Images.Source = "auto" - mgr, err := NewImageManager(cfg) + mgr, err := NewImageManager(io.Local, cfg) assert.NoError(t, err) assert.NotNil(t, mgr) assert.Len(t, mgr.sources, 2) // github and cdn @@ -136,7 +138,7 @@ func TestNewImageManager_Good_UnknownSourceFallsToAuto(t *testing.T) { cfg := DefaultConfig() cfg.Images.Source = "unknown" - mgr, err := NewImageManager(cfg) + mgr, err := NewImageManager(io.Local, cfg) assert.NoError(t, err) assert.NotNil(t, mgr) assert.Len(t, mgr.sources, 2) // falls to default (auto) which is github + cdn @@ -146,7 +148,7 @@ func TestLoadManifest_Good_Empty(t *testing.T) { tmpDir := t.TempDir() path := filepath.Join(tmpDir, "nonexistent.json") - m, err := loadManifest(path) + m, err := loadManifest(io.Local, path) assert.NoError(t, err) assert.NotNil(t, m) assert.NotNil(t, m.Images) @@ -162,7 +164,7 @@ func TestLoadManifest_Good_ExistingData(t *testing.T) { err := os.WriteFile(path, []byte(data), 0644) require.NoError(t, err) - m, err := loadManifest(path) + m, err := loadManifest(io.Local, path) assert.NoError(t, err) assert.NotNil(t, m) assert.Equal(t, "2.0.0", m.Images["test.img"].Version) @@ -187,6 +189,7 @@ func TestManifest_Save_Good_CreatesDirs(t *testing.T) { nestedPath := filepath.Join(tmpDir, "nested", "dir", "manifest.json") m := &Manifest{ + medium: io.Local, Images: make(map[string]ImageInfo), path: nestedPath, } @@ -207,6 +210,7 @@ func TestManifest_Save_Good_Overwrite(t *testing.T) { // First save m1 := &Manifest{ + medium: io.Local, Images: make(map[string]ImageInfo), path: path, } @@ -216,6 +220,7 @@ func TestManifest_Save_Good_Overwrite(t *testing.T) { // Second save with different data m2 := &Manifest{ + medium: io.Local, Images: make(map[string]ImageInfo), path: path, } @@ -224,7 +229,7 @@ func TestManifest_Save_Good_Overwrite(t *testing.T) { require.NoError(t, err) // Verify second data - loaded, err := loadManifest(path) + loaded, err := loadManifest(io.Local, path) assert.NoError(t, err) assert.Equal(t, "2.0.0", loaded.Images["other.img"].Version) _, exists := loaded.Images["test.img"] @@ -237,8 +242,9 @@ func TestImageManager_Install_Bad_NoSourceAvailable(t *testing.T) { // Create manager with empty sources mgr := &ImageManager{ + medium: io.Local, config: DefaultConfig(), - manifest: &Manifest{Images: make(map[string]ImageInfo), path: filepath.Join(tmpDir, "manifest.json")}, + manifest: &Manifest{medium: io.Local, Images: make(map[string]ImageInfo), path: filepath.Join(tmpDir, "manifest.json")}, sources: nil, // no sources } @@ -253,7 +259,7 @@ func TestNewImageManager_Good_CreatesDir(t *testing.T) { t.Setenv("CORE_IMAGES_DIR", imagesDir) cfg := DefaultConfig() - mgr, err := NewImageManager(cfg) + mgr, err := NewImageManager(io.Local, cfg) assert.NoError(t, err) assert.NotNil(t, mgr) @@ -277,7 +283,7 @@ func (m *mockImageSource) Available() bool { return m.available } func (m *mockImageSource) LatestVersion(ctx context.Context) (string, error) { return m.latestVersion, m.latestErr } -func (m *mockImageSource) Download(ctx context.Context, dest string, progress func(downloaded, total int64)) error { +func (m *mockImageSource) Download(ctx context.Context, medium io.Medium, dest string, progress func(downloaded, total int64)) error { if m.downloadErr != nil { return m.downloadErr } @@ -297,8 +303,9 @@ func TestImageManager_Install_Good_WithMockSource(t *testing.T) { } mgr := &ImageManager{ + medium: io.Local, config: DefaultConfig(), - manifest: &Manifest{Images: make(map[string]ImageInfo), path: filepath.Join(tmpDir, "manifest.json")}, + manifest: &Manifest{medium: io.Local, Images: make(map[string]ImageInfo), path: filepath.Join(tmpDir, "manifest.json")}, sources: []sources.ImageSource{mock}, } @@ -325,8 +332,9 @@ func TestImageManager_Install_Bad_DownloadError(t *testing.T) { } mgr := &ImageManager{ + medium: io.Local, config: DefaultConfig(), - manifest: &Manifest{Images: make(map[string]ImageInfo), path: filepath.Join(tmpDir, "manifest.json")}, + manifest: &Manifest{medium: io.Local, Images: make(map[string]ImageInfo), path: filepath.Join(tmpDir, "manifest.json")}, sources: []sources.ImageSource{mock}, } @@ -345,8 +353,9 @@ func TestImageManager_Install_Bad_VersionError(t *testing.T) { } mgr := &ImageManager{ + medium: io.Local, config: DefaultConfig(), - manifest: &Manifest{Images: make(map[string]ImageInfo), path: filepath.Join(tmpDir, "manifest.json")}, + manifest: &Manifest{medium: io.Local, Images: make(map[string]ImageInfo), path: filepath.Join(tmpDir, "manifest.json")}, sources: []sources.ImageSource{mock}, } @@ -370,8 +379,9 @@ func TestImageManager_Install_Good_SkipsUnavailableSource(t *testing.T) { } mgr := &ImageManager{ + medium: io.Local, config: DefaultConfig(), - manifest: &Manifest{Images: make(map[string]ImageInfo), path: filepath.Join(tmpDir, "manifest.json")}, + manifest: &Manifest{medium: io.Local, Images: make(map[string]ImageInfo), path: filepath.Join(tmpDir, "manifest.json")}, sources: []sources.ImageSource{unavailableMock, availableMock}, } @@ -394,8 +404,10 @@ func TestImageManager_CheckUpdate_Good_WithMockSource(t *testing.T) { } mgr := &ImageManager{ + medium: io.Local, config: DefaultConfig(), manifest: &Manifest{ + medium: io.Local, Images: map[string]ImageInfo{ ImageName(): {Version: "v1.0.0", Source: "mock"}, }, @@ -422,8 +434,10 @@ func TestImageManager_CheckUpdate_Good_NoUpdate(t *testing.T) { } mgr := &ImageManager{ + medium: io.Local, config: DefaultConfig(), manifest: &Manifest{ + medium: io.Local, Images: map[string]ImageInfo{ ImageName(): {Version: "v1.0.0", Source: "mock"}, }, @@ -449,8 +463,10 @@ func TestImageManager_CheckUpdate_Bad_NoSource(t *testing.T) { } mgr := &ImageManager{ + medium: io.Local, config: DefaultConfig(), manifest: &Manifest{ + medium: io.Local, Images: map[string]ImageInfo{ ImageName(): {Version: "v1.0.0", Source: "mock"}, }, @@ -475,8 +491,10 @@ func TestImageManager_CheckUpdate_Bad_VersionError(t *testing.T) { } mgr := &ImageManager{ + medium: io.Local, config: DefaultConfig(), manifest: &Manifest{ + medium: io.Local, Images: map[string]ImageInfo{ ImageName(): {Version: "v1.0.0", Source: "mock"}, }, @@ -495,8 +513,9 @@ func TestImageManager_Install_Bad_EmptySources(t *testing.T) { t.Setenv("CORE_IMAGES_DIR", tmpDir) mgr := &ImageManager{ + medium: io.Local, config: DefaultConfig(), - manifest: &Manifest{Images: make(map[string]ImageInfo), path: filepath.Join(tmpDir, "manifest.json")}, + manifest: &Manifest{medium: io.Local, Images: make(map[string]ImageInfo), path: filepath.Join(tmpDir, "manifest.json")}, sources: []sources.ImageSource{}, // Empty slice, not nil } @@ -513,8 +532,9 @@ func TestImageManager_Install_Bad_AllUnavailable(t *testing.T) { mock2 := &mockImageSource{name: "mock2", available: false} mgr := &ImageManager{ + medium: io.Local, config: DefaultConfig(), - manifest: &Manifest{Images: make(map[string]ImageInfo), path: filepath.Join(tmpDir, "manifest.json")}, + manifest: &Manifest{medium: io.Local, Images: make(map[string]ImageInfo), path: filepath.Join(tmpDir, "manifest.json")}, sources: []sources.ImageSource{mock1, mock2}, } @@ -531,8 +551,10 @@ func TestImageManager_CheckUpdate_Good_FirstSourceUnavailable(t *testing.T) { available := &mockImageSource{name: "available", available: true, latestVersion: "v2.0.0"} mgr := &ImageManager{ + medium: io.Local, config: DefaultConfig(), manifest: &Manifest{ + medium: io.Local, Images: map[string]ImageInfo{ ImageName(): {Version: "v1.0.0", Source: "available"}, }, diff --git a/pkg/devops/serve.go b/pkg/devops/serve.go index 78f784b1..1e0dc802 100644 --- a/pkg/devops/serve.go +++ b/pkg/devops/serve.go @@ -6,6 +6,8 @@ import ( "os" "os/exec" "path/filepath" + + "github.com/host-uk/core/pkg/io" ) // ServeOptions configures the dev server. @@ -39,7 +41,7 @@ func (d *DevOps) Serve(ctx context.Context, projectDir string, opts ServeOptions } // Detect and run serve command - serveCmd := DetectServeCommand(servePath) + serveCmd := DetectServeCommand(d.medium, servePath) fmt.Printf("Starting server: %s\n", serveCmd) fmt.Printf("Listening on http://localhost:%d\n", opts.Port) @@ -69,36 +71,36 @@ func (d *DevOps) mountProject(ctx context.Context, path string) error { } // DetectServeCommand auto-detects the serve command for a project. -func DetectServeCommand(projectDir string) string { +func DetectServeCommand(m io.Medium, projectDir string) string { // Laravel/Octane - if hasFile(projectDir, "artisan") { + if hasFile(m, projectDir, "artisan") { return "php artisan octane:start --host=0.0.0.0 --port=8000" } // Node.js with dev script - if hasFile(projectDir, "package.json") { - if hasPackageScript(projectDir, "dev") { + if hasFile(m, projectDir, "package.json") { + if hasPackageScript(m, projectDir, "dev") { return "npm run dev -- --host 0.0.0.0" } - if hasPackageScript(projectDir, "start") { + if hasPackageScript(m, projectDir, "start") { return "npm start" } } // PHP with composer - if hasFile(projectDir, "composer.json") { + if hasFile(m, projectDir, "composer.json") { return "frankenphp php-server -l :8000" } // Go - if hasFile(projectDir, "go.mod") { - if hasFile(projectDir, "main.go") { + if hasFile(m, projectDir, "go.mod") { + if hasFile(m, projectDir, "main.go") { return "go run ." } } // Python Django - if hasFile(projectDir, "manage.py") { + if hasFile(m, projectDir, "manage.py") { return "python manage.py runserver 0.0.0.0:8000" } diff --git a/pkg/devops/serve_test.go b/pkg/devops/serve_test.go index 54e1949f..57dc8362 100644 --- a/pkg/devops/serve_test.go +++ b/pkg/devops/serve_test.go @@ -5,6 +5,7 @@ import ( "path/filepath" "testing" + "github.com/host-uk/core/pkg/io" "github.com/stretchr/testify/assert" ) @@ -13,7 +14,7 @@ func TestDetectServeCommand_Good_Laravel(t *testing.T) { err := os.WriteFile(filepath.Join(tmpDir, "artisan"), []byte("#!/usr/bin/env php"), 0644) assert.NoError(t, err) - cmd := DetectServeCommand(tmpDir) + cmd := DetectServeCommand(io.Local, tmpDir) assert.Equal(t, "php artisan octane:start --host=0.0.0.0 --port=8000", cmd) } @@ -23,7 +24,7 @@ func TestDetectServeCommand_Good_NodeDev(t *testing.T) { err := os.WriteFile(filepath.Join(tmpDir, "package.json"), []byte(packageJSON), 0644) assert.NoError(t, err) - cmd := DetectServeCommand(tmpDir) + cmd := DetectServeCommand(io.Local, tmpDir) assert.Equal(t, "npm run dev -- --host 0.0.0.0", cmd) } @@ -33,7 +34,7 @@ func TestDetectServeCommand_Good_NodeStart(t *testing.T) { err := os.WriteFile(filepath.Join(tmpDir, "package.json"), []byte(packageJSON), 0644) assert.NoError(t, err) - cmd := DetectServeCommand(tmpDir) + cmd := DetectServeCommand(io.Local, tmpDir) assert.Equal(t, "npm start", cmd) } @@ -42,7 +43,7 @@ func TestDetectServeCommand_Good_PHP(t *testing.T) { err := os.WriteFile(filepath.Join(tmpDir, "composer.json"), []byte(`{"require":{}}`), 0644) assert.NoError(t, err) - cmd := DetectServeCommand(tmpDir) + cmd := DetectServeCommand(io.Local, tmpDir) assert.Equal(t, "frankenphp php-server -l :8000", cmd) } @@ -53,7 +54,7 @@ func TestDetectServeCommand_Good_GoMain(t *testing.T) { err = os.WriteFile(filepath.Join(tmpDir, "main.go"), []byte("package main"), 0644) assert.NoError(t, err) - cmd := DetectServeCommand(tmpDir) + cmd := DetectServeCommand(io.Local, tmpDir) assert.Equal(t, "go run .", cmd) } @@ -63,7 +64,7 @@ func TestDetectServeCommand_Good_GoWithoutMain(t *testing.T) { assert.NoError(t, err) // No main.go, so falls through to fallback - cmd := DetectServeCommand(tmpDir) + cmd := DetectServeCommand(io.Local, tmpDir) assert.Equal(t, "python3 -m http.server 8000", cmd) } @@ -72,14 +73,14 @@ func TestDetectServeCommand_Good_Django(t *testing.T) { err := os.WriteFile(filepath.Join(tmpDir, "manage.py"), []byte("#!/usr/bin/env python"), 0644) assert.NoError(t, err) - cmd := DetectServeCommand(tmpDir) + cmd := DetectServeCommand(io.Local, tmpDir) assert.Equal(t, "python manage.py runserver 0.0.0.0:8000", cmd) } func TestDetectServeCommand_Good_Fallback(t *testing.T) { tmpDir := t.TempDir() - cmd := DetectServeCommand(tmpDir) + cmd := DetectServeCommand(io.Local, tmpDir) assert.Equal(t, "python3 -m http.server 8000", cmd) } @@ -91,7 +92,7 @@ func TestDetectServeCommand_Good_Priority(t *testing.T) { err = os.WriteFile(filepath.Join(tmpDir, "composer.json"), []byte(`{"require":{}}`), 0644) assert.NoError(t, err) - cmd := DetectServeCommand(tmpDir) + cmd := DetectServeCommand(io.Local, tmpDir) assert.Equal(t, "php artisan octane:start --host=0.0.0.0 --port=8000", cmd) } @@ -116,13 +117,13 @@ func TestHasFile_Good(t *testing.T) { err := os.WriteFile(testFile, []byte("content"), 0644) assert.NoError(t, err) - assert.True(t, hasFile(tmpDir, "test.txt")) + assert.True(t, hasFile(io.Local, tmpDir, "test.txt")) } func TestHasFile_Bad(t *testing.T) { tmpDir := t.TempDir() - assert.False(t, hasFile(tmpDir, "nonexistent.txt")) + assert.False(t, hasFile(io.Local, tmpDir, "nonexistent.txt")) } func TestHasFile_Bad_Directory(t *testing.T) { @@ -132,5 +133,5 @@ func TestHasFile_Bad_Directory(t *testing.T) { assert.NoError(t, err) // hasFile correctly returns false for directories (only true for regular files) - assert.False(t, hasFile(tmpDir, "subdir")) + assert.False(t, hasFile(io.Local, tmpDir, "subdir")) } diff --git a/pkg/devops/sources/cdn.go b/pkg/devops/sources/cdn.go index 41269624..8408cf88 100644 --- a/pkg/devops/sources/cdn.go +++ b/pkg/devops/sources/cdn.go @@ -54,7 +54,7 @@ func (s *CDNSource) LatestVersion(ctx context.Context) (string, error) { } // Download downloads the image from CDN. -func (s *CDNSource) Download(ctx context.Context, dest string, progress func(downloaded, total int64)) error { +func (s *CDNSource) Download(ctx context.Context, m io.Medium, dest string, progress func(downloaded, total int64)) error { url := fmt.Sprintf("%s/%s", s.config.CDNURL, s.config.ImageName) req, err := http.NewRequestWithContext(ctx, "GET", url, nil) @@ -73,7 +73,7 @@ func (s *CDNSource) Download(ctx context.Context, dest string, progress func(dow } // Ensure dest directory exists - if err := io.Local.EnsureDir(dest); err != nil { + if err := m.EnsureDir(dest); err != nil { return fmt.Errorf("cdn.Download: %w", err) } diff --git a/pkg/devops/sources/cdn_test.go b/pkg/devops/sources/cdn_test.go index de9c9639..2fe33c85 100644 --- a/pkg/devops/sources/cdn_test.go +++ b/pkg/devops/sources/cdn_test.go @@ -9,6 +9,7 @@ import ( "path/filepath" "testing" + "github.com/host-uk/core/pkg/io" "github.com/stretchr/testify/assert" ) @@ -71,7 +72,7 @@ func TestCDNSource_Download_Good(t *testing.T) { }) var progressCalled bool - err := src.Download(context.Background(), dest, func(downloaded, total int64) { + err := src.Download(context.Background(), io.Local, dest, func(downloaded, total int64) { progressCalled = true }) @@ -97,7 +98,7 @@ func TestCDNSource_Download_Bad(t *testing.T) { ImageName: "test.img", }) - err := src.Download(context.Background(), dest, nil) + err := src.Download(context.Background(), io.Local, dest, nil) assert.Error(t, err) assert.Contains(t, err.Error(), "HTTP 500") }) @@ -109,7 +110,7 @@ func TestCDNSource_Download_Bad(t *testing.T) { ImageName: "test.img", }) - err := src.Download(context.Background(), dest, nil) + err := src.Download(context.Background(), io.Local, dest, nil) assert.Error(t, err) }) } @@ -162,7 +163,7 @@ func TestCDNSource_Download_Good_NoProgress(t *testing.T) { }) // nil progress callback should be handled gracefully - err := src.Download(context.Background(), dest, nil) + err := src.Download(context.Background(), io.Local, dest, nil) assert.NoError(t, err) data, err := os.ReadFile(filepath.Join(dest, "test.img")) @@ -192,7 +193,7 @@ func TestCDNSource_Download_Good_LargeFile(t *testing.T) { var progressCalls int var lastDownloaded int64 - err := src.Download(context.Background(), dest, func(downloaded, total int64) { + err := src.Download(context.Background(), io.Local, dest, func(downloaded, total int64) { progressCalls++ lastDownloaded = downloaded }) @@ -227,7 +228,7 @@ func TestCDNSource_Download_Bad_HTTPErrorCodes(t *testing.T) { ImageName: "test.img", }) - err := src.Download(context.Background(), dest, nil) + err := src.Download(context.Background(), io.Local, dest, nil) assert.Error(t, err) assert.Contains(t, err.Error(), fmt.Sprintf("HTTP %d", tc.statusCode)) }) @@ -281,7 +282,7 @@ func TestCDNSource_Download_Good_CreatesDestDir(t *testing.T) { ImageName: "test.img", }) - err := src.Download(context.Background(), dest, nil) + err := src.Download(context.Background(), io.Local, dest, nil) assert.NoError(t, err) // Verify nested dir was created diff --git a/pkg/devops/sources/github.go b/pkg/devops/sources/github.go index 98a86b67..323f2dda 100644 --- a/pkg/devops/sources/github.go +++ b/pkg/devops/sources/github.go @@ -6,6 +6,8 @@ import ( "os" "os/exec" "strings" + + "github.com/host-uk/core/pkg/io" ) // GitHubSource downloads images from GitHub Releases. @@ -52,7 +54,7 @@ func (s *GitHubSource) LatestVersion(ctx context.Context) (string, error) { } // Download downloads the image from the latest release. -func (s *GitHubSource) Download(ctx context.Context, dest string, progress func(downloaded, total int64)) error { +func (s *GitHubSource) Download(ctx context.Context, m io.Medium, dest string, progress func(downloaded, total int64)) error { // Get release assets to find our image cmd := exec.CommandContext(ctx, "gh", "release", "download", "-R", s.config.GitHubRepo, diff --git a/pkg/devops/sources/source.go b/pkg/devops/sources/source.go index 94e4ff68..f5ca4460 100644 --- a/pkg/devops/sources/source.go +++ b/pkg/devops/sources/source.go @@ -3,6 +3,8 @@ package sources import ( "context" + + "github.com/host-uk/core/pkg/io" ) // ImageSource defines the interface for downloading dev images. @@ -15,7 +17,7 @@ type ImageSource interface { LatestVersion(ctx context.Context) (string, error) // Download downloads the image to the destination path. // Reports progress via the callback if provided. - Download(ctx context.Context, dest string, progress func(downloaded, total int64)) error + Download(ctx context.Context, m io.Medium, dest string, progress func(downloaded, total int64)) error } // SourceConfig holds configuration for a source. diff --git a/pkg/devops/test.go b/pkg/devops/test.go index e424472e..89d1726c 100644 --- a/pkg/devops/test.go +++ b/pkg/devops/test.go @@ -47,7 +47,7 @@ func (d *DevOps) Test(ctx context.Context, projectDir string, opts TestOptions) if len(opts.Command) > 0 { cmd = strings.Join(opts.Command, " ") } else if opts.Name != "" { - cfg, err := LoadTestConfig(projectDir) + cfg, err := LoadTestConfig(d.medium, projectDir) if err != nil { return err } @@ -61,7 +61,7 @@ func (d *DevOps) Test(ctx context.Context, projectDir string, opts TestOptions) return fmt.Errorf("test command %q not found in .core/test.yaml", opts.Name) } } else { - cmd = DetectTestCommand(projectDir) + cmd = DetectTestCommand(d.medium, projectDir) if cmd == "" { return fmt.Errorf("could not detect test command (create .core/test.yaml)") } @@ -72,39 +72,39 @@ func (d *DevOps) Test(ctx context.Context, projectDir string, opts TestOptions) } // DetectTestCommand auto-detects the test command for a project. -func DetectTestCommand(projectDir string) string { +func DetectTestCommand(m io.Medium, projectDir string) string { // 1. Check .core/test.yaml - cfg, err := LoadTestConfig(projectDir) + cfg, err := LoadTestConfig(m, projectDir) if err == nil && cfg.Command != "" { return cfg.Command } // 2. Check composer.json for test script - if hasFile(projectDir, "composer.json") { - if hasComposerScript(projectDir, "test") { + if hasFile(m, projectDir, "composer.json") { + if hasComposerScript(m, projectDir, "test") { return "composer test" } } // 3. Check package.json for test script - if hasFile(projectDir, "package.json") { - if hasPackageScript(projectDir, "test") { + if hasFile(m, projectDir, "package.json") { + if hasPackageScript(m, projectDir, "test") { return "npm test" } } // 4. Check go.mod - if hasFile(projectDir, "go.mod") { + if hasFile(m, projectDir, "go.mod") { return "go test ./..." } // 5. Check pytest - if hasFile(projectDir, "pytest.ini") || hasFile(projectDir, "pyproject.toml") { + if hasFile(m, projectDir, "pytest.ini") || hasFile(m, projectDir, "pyproject.toml") { return "pytest" } // 6. Check Taskfile - if hasFile(projectDir, "Taskfile.yaml") || hasFile(projectDir, "Taskfile.yml") { + if hasFile(m, projectDir, "Taskfile.yaml") || hasFile(m, projectDir, "Taskfile.yml") { return "task test" } @@ -112,14 +112,14 @@ func DetectTestCommand(projectDir string) string { } // LoadTestConfig loads .core/test.yaml. -func LoadTestConfig(projectDir string) (*TestConfig, error) { +func LoadTestConfig(m io.Medium, projectDir string) (*TestConfig, error) { path := filepath.Join(projectDir, ".core", "test.yaml") absPath, err := filepath.Abs(path) if err != nil { return nil, err } - content, err := io.Local.Read(absPath) + content, err := m.Read(absPath) if err != nil { return nil, err } @@ -132,23 +132,23 @@ func LoadTestConfig(projectDir string) (*TestConfig, error) { return &cfg, nil } -func hasFile(dir, name string) bool { +func hasFile(m io.Medium, dir, name string) bool { path := filepath.Join(dir, name) absPath, err := filepath.Abs(path) if err != nil { return false } - return io.Local.IsFile(absPath) + return m.IsFile(absPath) } -func hasPackageScript(projectDir, script string) bool { +func hasPackageScript(m io.Medium, projectDir, script string) bool { path := filepath.Join(projectDir, "package.json") absPath, err := filepath.Abs(path) if err != nil { return false } - content, err := io.Local.Read(absPath) + content, err := m.Read(absPath) if err != nil { return false } @@ -164,14 +164,14 @@ func hasPackageScript(projectDir, script string) bool { return ok } -func hasComposerScript(projectDir, script string) bool { +func hasComposerScript(m io.Medium, projectDir, script string) bool { path := filepath.Join(projectDir, "composer.json") absPath, err := filepath.Abs(path) if err != nil { return false } - content, err := io.Local.Read(absPath) + content, err := m.Read(absPath) if err != nil { return false } diff --git a/pkg/devops/test_test.go b/pkg/devops/test_test.go index 2a20e6e2..8f4cff77 100644 --- a/pkg/devops/test_test.go +++ b/pkg/devops/test_test.go @@ -4,13 +4,15 @@ import ( "os" "path/filepath" "testing" + + "github.com/host-uk/core/pkg/io" ) func TestDetectTestCommand_Good_ComposerJSON(t *testing.T) { tmpDir := t.TempDir() _ = os.WriteFile(filepath.Join(tmpDir, "composer.json"), []byte(`{"scripts":{"test":"pest"}}`), 0644) - cmd := DetectTestCommand(tmpDir) + cmd := DetectTestCommand(io.Local, tmpDir) if cmd != "composer test" { t.Errorf("expected 'composer test', got %q", cmd) } @@ -20,7 +22,7 @@ func TestDetectTestCommand_Good_PackageJSON(t *testing.T) { tmpDir := t.TempDir() _ = os.WriteFile(filepath.Join(tmpDir, "package.json"), []byte(`{"scripts":{"test":"vitest"}}`), 0644) - cmd := DetectTestCommand(tmpDir) + cmd := DetectTestCommand(io.Local, tmpDir) if cmd != "npm test" { t.Errorf("expected 'npm test', got %q", cmd) } @@ -30,7 +32,7 @@ func TestDetectTestCommand_Good_GoMod(t *testing.T) { tmpDir := t.TempDir() _ = os.WriteFile(filepath.Join(tmpDir, "go.mod"), []byte("module example"), 0644) - cmd := DetectTestCommand(tmpDir) + cmd := DetectTestCommand(io.Local, tmpDir) if cmd != "go test ./..." { t.Errorf("expected 'go test ./...', got %q", cmd) } @@ -42,7 +44,7 @@ func TestDetectTestCommand_Good_CoreTestYaml(t *testing.T) { _ = os.MkdirAll(coreDir, 0755) _ = os.WriteFile(filepath.Join(coreDir, "test.yaml"), []byte("command: custom-test"), 0644) - cmd := DetectTestCommand(tmpDir) + cmd := DetectTestCommand(io.Local, tmpDir) if cmd != "custom-test" { t.Errorf("expected 'custom-test', got %q", cmd) } @@ -52,7 +54,7 @@ func TestDetectTestCommand_Good_Pytest(t *testing.T) { tmpDir := t.TempDir() _ = os.WriteFile(filepath.Join(tmpDir, "pytest.ini"), []byte("[pytest]"), 0644) - cmd := DetectTestCommand(tmpDir) + cmd := DetectTestCommand(io.Local, tmpDir) if cmd != "pytest" { t.Errorf("expected 'pytest', got %q", cmd) } @@ -62,7 +64,7 @@ func TestDetectTestCommand_Good_Taskfile(t *testing.T) { tmpDir := t.TempDir() _ = os.WriteFile(filepath.Join(tmpDir, "Taskfile.yaml"), []byte("version: '3'"), 0644) - cmd := DetectTestCommand(tmpDir) + cmd := DetectTestCommand(io.Local, tmpDir) if cmd != "task test" { t.Errorf("expected 'task test', got %q", cmd) } @@ -71,7 +73,7 @@ func TestDetectTestCommand_Good_Taskfile(t *testing.T) { func TestDetectTestCommand_Bad_NoFiles(t *testing.T) { tmpDir := t.TempDir() - cmd := DetectTestCommand(tmpDir) + cmd := DetectTestCommand(io.Local, tmpDir) if cmd != "" { t.Errorf("expected empty string, got %q", cmd) } @@ -85,7 +87,7 @@ func TestDetectTestCommand_Good_Priority(t *testing.T) { _ = os.WriteFile(filepath.Join(coreDir, "test.yaml"), []byte("command: my-custom-test"), 0644) _ = os.WriteFile(filepath.Join(tmpDir, "go.mod"), []byte("module example"), 0644) - cmd := DetectTestCommand(tmpDir) + cmd := DetectTestCommand(io.Local, tmpDir) if cmd != "my-custom-test" { t.Errorf("expected 'my-custom-test' (from .core/test.yaml), got %q", cmd) } @@ -108,7 +110,7 @@ env: ` _ = os.WriteFile(filepath.Join(coreDir, "test.yaml"), []byte(configYAML), 0644) - cfg, err := LoadTestConfig(tmpDir) + cfg, err := LoadTestConfig(io.Local, tmpDir) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -133,7 +135,7 @@ env: func TestLoadTestConfig_Bad_NotFound(t *testing.T) { tmpDir := t.TempDir() - _, err := LoadTestConfig(tmpDir) + _, err := LoadTestConfig(io.Local, tmpDir) if err == nil { t.Error("expected error for missing config, got nil") } @@ -143,10 +145,10 @@ func TestHasPackageScript_Good(t *testing.T) { tmpDir := t.TempDir() _ = os.WriteFile(filepath.Join(tmpDir, "package.json"), []byte(`{"scripts":{"test":"jest","build":"webpack"}}`), 0644) - if !hasPackageScript(tmpDir, "test") { + if !hasPackageScript(io.Local, tmpDir, "test") { t.Error("expected to find 'test' script") } - if !hasPackageScript(tmpDir, "build") { + if !hasPackageScript(io.Local, tmpDir, "build") { t.Error("expected to find 'build' script") } } @@ -155,7 +157,7 @@ func TestHasPackageScript_Bad_MissingScript(t *testing.T) { tmpDir := t.TempDir() _ = os.WriteFile(filepath.Join(tmpDir, "package.json"), []byte(`{"scripts":{"build":"webpack"}}`), 0644) - if hasPackageScript(tmpDir, "test") { + if hasPackageScript(io.Local, tmpDir, "test") { t.Error("expected not to find 'test' script") } } @@ -164,7 +166,7 @@ func TestHasComposerScript_Good(t *testing.T) { tmpDir := t.TempDir() _ = os.WriteFile(filepath.Join(tmpDir, "composer.json"), []byte(`{"scripts":{"test":"pest","post-install-cmd":"@php artisan migrate"}}`), 0644) - if !hasComposerScript(tmpDir, "test") { + if !hasComposerScript(io.Local, tmpDir, "test") { t.Error("expected to find 'test' script") } } @@ -173,7 +175,7 @@ func TestHasComposerScript_Bad_MissingScript(t *testing.T) { tmpDir := t.TempDir() _ = os.WriteFile(filepath.Join(tmpDir, "composer.json"), []byte(`{"scripts":{"build":"@php build.php"}}`), 0644) - if hasComposerScript(tmpDir, "test") { + if hasComposerScript(io.Local, tmpDir, "test") { t.Error("expected not to find 'test' script") } } @@ -229,7 +231,7 @@ func TestDetectTestCommand_Good_TaskfileYml(t *testing.T) { tmpDir := t.TempDir() _ = os.WriteFile(filepath.Join(tmpDir, "Taskfile.yml"), []byte("version: '3'"), 0644) - cmd := DetectTestCommand(tmpDir) + cmd := DetectTestCommand(io.Local, tmpDir) if cmd != "task test" { t.Errorf("expected 'task test', got %q", cmd) } @@ -239,7 +241,7 @@ func TestDetectTestCommand_Good_Pyproject(t *testing.T) { tmpDir := t.TempDir() _ = os.WriteFile(filepath.Join(tmpDir, "pyproject.toml"), []byte("[tool.pytest]"), 0644) - cmd := DetectTestCommand(tmpDir) + cmd := DetectTestCommand(io.Local, tmpDir) if cmd != "pytest" { t.Errorf("expected 'pytest', got %q", cmd) } @@ -248,7 +250,7 @@ func TestDetectTestCommand_Good_Pyproject(t *testing.T) { func TestHasPackageScript_Bad_NoFile(t *testing.T) { tmpDir := t.TempDir() - if hasPackageScript(tmpDir, "test") { + if hasPackageScript(io.Local, tmpDir, "test") { t.Error("expected false for missing package.json") } } @@ -257,7 +259,7 @@ func TestHasPackageScript_Bad_InvalidJSON(t *testing.T) { tmpDir := t.TempDir() _ = os.WriteFile(filepath.Join(tmpDir, "package.json"), []byte(`invalid json`), 0644) - if hasPackageScript(tmpDir, "test") { + if hasPackageScript(io.Local, tmpDir, "test") { t.Error("expected false for invalid JSON") } } @@ -266,7 +268,7 @@ func TestHasPackageScript_Bad_NoScripts(t *testing.T) { tmpDir := t.TempDir() _ = os.WriteFile(filepath.Join(tmpDir, "package.json"), []byte(`{"name":"test"}`), 0644) - if hasPackageScript(tmpDir, "test") { + if hasPackageScript(io.Local, tmpDir, "test") { t.Error("expected false for missing scripts section") } } @@ -274,7 +276,7 @@ func TestHasPackageScript_Bad_NoScripts(t *testing.T) { func TestHasComposerScript_Bad_NoFile(t *testing.T) { tmpDir := t.TempDir() - if hasComposerScript(tmpDir, "test") { + if hasComposerScript(io.Local, tmpDir, "test") { t.Error("expected false for missing composer.json") } } @@ -283,7 +285,7 @@ func TestHasComposerScript_Bad_InvalidJSON(t *testing.T) { tmpDir := t.TempDir() _ = os.WriteFile(filepath.Join(tmpDir, "composer.json"), []byte(`invalid json`), 0644) - if hasComposerScript(tmpDir, "test") { + if hasComposerScript(io.Local, tmpDir, "test") { t.Error("expected false for invalid JSON") } } @@ -292,7 +294,7 @@ func TestHasComposerScript_Bad_NoScripts(t *testing.T) { tmpDir := t.TempDir() _ = os.WriteFile(filepath.Join(tmpDir, "composer.json"), []byte(`{"name":"test/pkg"}`), 0644) - if hasComposerScript(tmpDir, "test") { + if hasComposerScript(io.Local, tmpDir, "test") { t.Error("expected false for missing scripts section") } } @@ -303,7 +305,7 @@ func TestLoadTestConfig_Bad_InvalidYAML(t *testing.T) { _ = os.MkdirAll(coreDir, 0755) _ = os.WriteFile(filepath.Join(coreDir, "test.yaml"), []byte("invalid: yaml: :"), 0644) - _, err := LoadTestConfig(tmpDir) + _, err := LoadTestConfig(io.Local, tmpDir) if err == nil { t.Error("expected error for invalid YAML") } @@ -315,7 +317,7 @@ func TestLoadTestConfig_Good_MinimalConfig(t *testing.T) { _ = os.MkdirAll(coreDir, 0755) _ = os.WriteFile(filepath.Join(coreDir, "test.yaml"), []byte("version: 1"), 0644) - cfg, err := LoadTestConfig(tmpDir) + cfg, err := LoadTestConfig(io.Local, tmpDir) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -332,7 +334,7 @@ func TestDetectTestCommand_Good_ComposerWithoutScript(t *testing.T) { // composer.json without test script should not return composer test _ = os.WriteFile(filepath.Join(tmpDir, "composer.json"), []byte(`{"name":"test/pkg"}`), 0644) - cmd := DetectTestCommand(tmpDir) + cmd := DetectTestCommand(io.Local, tmpDir) // Falls through to empty (no match) if cmd != "" { t.Errorf("expected empty string, got %q", cmd) @@ -344,7 +346,7 @@ func TestDetectTestCommand_Good_PackageJSONWithoutScript(t *testing.T) { // package.json without test or dev script _ = os.WriteFile(filepath.Join(tmpDir, "package.json"), []byte(`{"name":"test"}`), 0644) - cmd := DetectTestCommand(tmpDir) + cmd := DetectTestCommand(io.Local, tmpDir) // Falls through to empty if cmd != "" { t.Errorf("expected empty string, got %q", cmd) diff --git a/pkg/io/io.go b/pkg/io/io.go index 7327b303..36b907c6 100644 --- a/pkg/io/io.go +++ b/pkg/io/io.go @@ -1,7 +1,7 @@ package io import ( - "io" + goio "io" "io/fs" "os" "path/filepath" @@ -53,7 +53,7 @@ type Medium interface { Open(path string) (fs.File, error) // Create creates or truncates the named file. - Create(path string) (io.WriteCloser, error) + Create(path string) (goio.WriteCloser, error) // Exists checks if a path exists (file or directory). Exists(path string) bool @@ -327,7 +327,7 @@ func (m *MockMedium) Open(path string) (fs.File, error) { } // Create creates a file in the mock filesystem. -func (m *MockMedium) Create(path string) (io.WriteCloser, error) { +func (m *MockMedium) Create(path string) (goio.WriteCloser, error) { return &MockWriteCloser{ medium: m, path: path, @@ -350,7 +350,7 @@ func (f *MockFile) Stat() (fs.FileInfo, error) { func (f *MockFile) Read(b []byte) (int, error) { if f.offset >= int64(len(f.content)) { - return 0, io.EOF + return 0, goio.EOF } n := copy(b, f.content[f.offset:]) f.offset += int64(n) diff --git a/pkg/io/local/client.go b/pkg/io/local/client.go index 8a8b5154..452afad3 100644 --- a/pkg/io/local/client.go +++ b/pkg/io/local/client.go @@ -2,7 +2,7 @@ package local import ( - "io" + goio "io" "io/fs" "os" "path/filepath" @@ -25,42 +25,79 @@ func New(root string) (*Medium, error) { } // path sanitizes and returns the full path. -// Replaces .. with . to prevent traversal, then joins with root. // Absolute paths are sandboxed under root (unless root is "/"). func (m *Medium) path(p string) string { if p == "" { return m.root } - clean := strings.ReplaceAll(p, "..", ".") - if filepath.IsAbs(clean) { - // If root is "/", allow absolute paths through - if m.root == "/" { - return filepath.Clean(clean) - } - // Otherwise, sandbox absolute paths by stripping volume + leading separators - vol := filepath.VolumeName(clean) - clean = strings.TrimPrefix(clean, vol) - cutset := string(os.PathSeparator) - if os.PathSeparator != '/' { - cutset += "/" - } - clean = strings.TrimLeft(clean, cutset) - return filepath.Join(m.root, clean) - } + // If the path is relative and the medium is rooted at "/", // treat it as relative to the current working directory. // This makes io.Local behave more like the standard 'os' package. - if m.root == "/" && !filepath.IsAbs(clean) { + if m.root == "/" && !filepath.IsAbs(p) { cwd, _ := os.Getwd() - return filepath.Join(cwd, clean) + return filepath.Join(cwd, p) } + // Use filepath.Clean with a leading slash to resolve all .. and . internally + // before joining with the root. This is a standard way to sandbox paths. + clean := filepath.Clean("/" + p) + + // If root is "/", allow absolute paths through + if m.root == "/" { + return clean + } + + // Join cleaned relative path with root return filepath.Join(m.root, clean) } +// validatePath ensures the path is within the sandbox, following symlinks if they exist. +func (m *Medium) validatePath(p string) (string, error) { + if m.root == "/" { + return m.path(p), nil + } + + // Split the cleaned path into components + parts := strings.Split(filepath.Clean("/"+p), string(os.PathSeparator)) + current := m.root + + for _, part := range parts { + if part == "" { + continue + } + + next := filepath.Join(current, part) + realNext, err := filepath.EvalSymlinks(next) + if err != nil { + if os.IsNotExist(err) { + // Part doesn't exist, we can't follow symlinks anymore. + // Since the path is already Cleaned and current is safe, + // appending a component to current will not escape. + current = next + continue + } + return "", err + } + + // Verify the resolved part is still within the root + rel, err := filepath.Rel(m.root, realNext) + if err != nil || strings.HasPrefix(rel, "..") { + return "", os.ErrPermission // Path escapes sandbox + } + current = realNext + } + + return current, nil +} + // Read returns file contents as string. func (m *Medium) Read(p string) (string, error) { - data, err := os.ReadFile(m.path(p)) + full, err := m.validatePath(p) + if err != nil { + return "", err + } + data, err := os.ReadFile(full) if err != nil { return "", err } @@ -69,7 +106,10 @@ func (m *Medium) Read(p string) (string, error) { // Write saves content to file, creating parent directories as needed. func (m *Medium) Write(p, content string) error { - full := m.path(p) + full, err := m.validatePath(p) + if err != nil { + return err + } if err := os.MkdirAll(filepath.Dir(full), 0755); err != nil { return err } @@ -78,7 +118,11 @@ func (m *Medium) Write(p, content string) error { // EnsureDir creates directory if it doesn't exist. func (m *Medium) EnsureDir(p string) error { - return os.MkdirAll(m.path(p), 0755) + full, err := m.validatePath(p) + if err != nil { + return err + } + return os.MkdirAll(full, 0755) } // IsDir returns true if path is a directory. @@ -86,7 +130,11 @@ func (m *Medium) IsDir(p string) bool { if p == "" { return false } - info, err := os.Stat(m.path(p)) + full, err := m.validatePath(p) + if err != nil { + return false + } + info, err := os.Stat(full) return err == nil && info.IsDir() } @@ -95,34 +143,57 @@ func (m *Medium) IsFile(p string) bool { if p == "" { return false } - info, err := os.Stat(m.path(p)) + full, err := m.validatePath(p) + if err != nil { + return false + } + info, err := os.Stat(full) return err == nil && info.Mode().IsRegular() } // Exists returns true if path exists. func (m *Medium) Exists(p string) bool { - _, err := os.Stat(m.path(p)) + full, err := m.validatePath(p) + if err != nil { + return false + } + _, err = os.Stat(full) return err == nil } // List returns directory entries. func (m *Medium) List(p string) ([]fs.DirEntry, error) { - return os.ReadDir(m.path(p)) + full, err := m.validatePath(p) + if err != nil { + return nil, err + } + return os.ReadDir(full) } // Stat returns file info. func (m *Medium) Stat(p string) (fs.FileInfo, error) { - return os.Stat(m.path(p)) + full, err := m.validatePath(p) + if err != nil { + return nil, err + } + return os.Stat(full) } // Open opens the named file for reading. func (m *Medium) Open(p string) (fs.File, error) { - return os.Open(m.path(p)) + full, err := m.validatePath(p) + if err != nil { + return nil, err + } + return os.Open(full) } // Create creates or truncates the named file. -func (m *Medium) Create(p string) (io.WriteCloser, error) { - full := m.path(p) +func (m *Medium) Create(p string) (goio.WriteCloser, error) { + full, err := m.validatePath(p) + if err != nil { + return nil, err + } if err := os.MkdirAll(filepath.Dir(full), 0755); err != nil { return nil, err } @@ -131,7 +202,10 @@ func (m *Medium) Create(p string) (io.WriteCloser, error) { // Delete removes a file or empty directory. func (m *Medium) Delete(p string) error { - full := m.path(p) + full, err := m.validatePath(p) + if err != nil { + return err + } if len(full) < 3 { return nil } @@ -140,7 +214,10 @@ func (m *Medium) Delete(p string) error { // DeleteAll removes a file or directory recursively. func (m *Medium) DeleteAll(p string) error { - full := m.path(p) + full, err := m.validatePath(p) + if err != nil { + return err + } if len(full) < 3 { return nil } @@ -149,7 +226,15 @@ func (m *Medium) DeleteAll(p string) error { // Rename moves a file or directory. func (m *Medium) Rename(oldPath, newPath string) error { - return os.Rename(m.path(oldPath), m.path(newPath)) + oldFull, err := m.validatePath(oldPath) + if err != nil { + return err + } + newFull, err := m.validatePath(newPath) + if err != nil { + return err + } + return os.Rename(oldFull, newFull) } // FileGet is an alias for Read. diff --git a/pkg/io/local/client_test.go b/pkg/io/local/client_test.go index 3cb5996a..7471174c 100644 --- a/pkg/io/local/client_test.go +++ b/pkg/io/local/client_test.go @@ -25,9 +25,9 @@ func TestPath(t *testing.T) { // Empty returns root assert.Equal(t, "/home/user", m.path("")) - // Traversal attempts get sanitized (.. becomes ., then cleaned by Join) + // Traversal attempts get sanitized assert.Equal(t, "/home/user/file.txt", m.path("../file.txt")) - assert.Equal(t, "/home/user/dir/file.txt", m.path("dir/../file.txt")) + assert.Equal(t, "/home/user/file.txt", m.path("dir/../file.txt")) // Absolute paths are constrained to sandbox (no escape) assert.Equal(t, "/home/user/etc/passwd", m.path("/etc/passwd")) diff --git a/pkg/mcp/mcp.go b/pkg/mcp/mcp.go index 9f07dbc8..0d3dba0d 100644 --- a/pkg/mcp/mcp.go +++ b/pkg/mcp/mcp.go @@ -10,6 +10,7 @@ import ( "strings" "github.com/host-uk/core/pkg/io" + "github.com/host-uk/core/pkg/io/local" "github.com/modelcontextprotocol/go-sdk/mcp" ) @@ -40,7 +41,7 @@ func WithWorkspaceRoot(root string) Option { if err != nil { return fmt.Errorf("invalid workspace root: %w", err) } - m, err := io.NewSandboxed(abs) + m, err := local.New(abs) if err != nil { return fmt.Errorf("failed to create workspace medium: %w", err) } @@ -69,7 +70,7 @@ func New(opts ...Option) (*Service, error) { return nil, fmt.Errorf("failed to get working directory: %w", err) } s.workspaceRoot = cwd - m, err := io.NewSandboxed(cwd) + m, err := local.New(cwd) if err != nil { return nil, fmt.Errorf("failed to create sandboxed medium: %w", err) } @@ -310,11 +311,8 @@ func (s *Service) listDirectory(ctx context.Context, req *mcp.CallToolRequest, i size = info.Size() } result = append(result, DirectoryEntry{ - Name: e.Name(), - Path: filepath.Join(input.Path, e.Name()), // Note: This might be relative path, client might expect absolute? - // Issue 103 says "Replace ... with local.Medium sandboxing". - // Previous code returned `filepath.Join(input.Path, e.Name())`. - // If input.Path is relative, this preserves it. + Name: e.Name(), + Path: filepath.Join(input.Path, e.Name()), IsDir: e.IsDir(), Size: size, }) @@ -344,21 +342,18 @@ func (s *Service) renameFile(ctx context.Context, req *mcp.CallToolRequest, inpu } func (s *Service) fileExists(ctx context.Context, req *mcp.CallToolRequest, input FileExistsInput) (*mcp.CallToolResult, FileExistsOutput, error) { - exists := s.medium.IsFile(input.Path) - if exists { - return nil, FileExistsOutput{Exists: true, IsDir: false, Path: input.Path}, nil + info, err := s.medium.Stat(input.Path) + if err != nil { + // Any error from Stat (e.g., not found, permission denied) is treated as "does not exist" + // for the purpose of this tool. + return nil, FileExistsOutput{Exists: false, IsDir: false, Path: input.Path}, nil } - // Check if it's a directory by attempting to list it - // List might fail if it's a file too (but we checked IsFile) or if doesn't exist. - _, err := s.medium.List(input.Path) - isDir := err == nil - // If List failed, it might mean it doesn't exist OR it's a special file or permissions. - // Assuming if List works, it's a directory. - - // Refinement: If it doesn't exist, List returns error. - - return nil, FileExistsOutput{Exists: isDir, IsDir: isDir, Path: input.Path}, nil + return nil, FileExistsOutput{ + Exists: true, + IsDir: info.IsDir(), + Path: input.Path, + }, nil } func (s *Service) detectLanguage(ctx context.Context, req *mcp.CallToolRequest, input DetectLanguageInput) (*mcp.CallToolResult, DetectLanguageOutput, error) { diff --git a/pkg/mcp/mcp_test.go b/pkg/mcp/mcp_test.go index 544d2da2..2172abda 100644 --- a/pkg/mcp/mcp_test.go +++ b/pkg/mcp/mcp_test.go @@ -144,12 +144,15 @@ func TestSandboxing_Traversal_Sanitized(t *testing.T) { t.Error("Expected error (file not found)") } - // Absolute paths are allowed through - they access the real filesystem. - // This is intentional for full filesystem access. Callers wanting sandboxing - // should validate inputs before calling Medium. + // Absolute paths are also sandboxed under the root directory. + // For example, /etc/passwd becomes /etc/passwd. + _, err = s.medium.Read("/etc/passwd") + if err == nil { + t.Error("Expected error (file not found in sandbox)") + } } -func TestSandboxing_Symlinks_Followed(t *testing.T) { +func TestSandboxing_Symlinks_Blocked(t *testing.T) { tmpDir := t.TempDir() outsideDir := t.TempDir() @@ -170,14 +173,15 @@ func TestSandboxing_Symlinks_Followed(t *testing.T) { t.Fatalf("Failed to create service: %v", err) } - // Symlinks are followed - no traversal blocking at Medium level. - // This is intentional for simplicity. Callers wanting to block symlinks - // should validate inputs before calling Medium. - content, err := s.medium.Read("link") - if err != nil { - t.Errorf("Expected symlink to be followed, got error: %v", err) + // Symlinks that escape the sandbox should be blocked. + _, err = s.medium.Read("link") + if err == nil { + t.Error("Expected error for symlink escaping sandbox, got nil") } - if content != "secret" { - t.Errorf("Expected 'secret', got '%s'", content) + + // Symlinks that escape the sandbox should be blocked even if target doesn't exist. + _, err = s.medium.Read("link/nonexistent") + if err == nil { + t.Error("Expected error for symlink/nonexistent escaping sandbox, got nil") } } diff --git a/pkg/release/config.go b/pkg/release/config.go index 2beefbf5..2f4d934e 100644 --- a/pkg/release/config.go +++ b/pkg/release/config.go @@ -169,14 +169,14 @@ type ChangelogConfig struct { // LoadConfig loads release configuration from the .core/release.yaml file in the given directory. // If the config file does not exist, it returns DefaultConfig(). // Returns an error if the file exists but cannot be parsed. -func LoadConfig(dir string) (*Config, error) { +func LoadConfig(m io.Medium, dir string) (*Config, error) { configPath := filepath.Join(dir, ConfigDir, ConfigFileName) absPath, err := filepath.Abs(configPath) if err != nil { return nil, fmt.Errorf("release.LoadConfig: failed to resolve path: %w", err) } - content, err := io.Local.Read(absPath) + content, err := m.Read(absPath) if err != nil { if os.IsNotExist(err) { cfg := DefaultConfig() @@ -266,13 +266,13 @@ func ConfigPath(dir string) string { } // ConfigExists checks if a release config file exists in the given directory. -func ConfigExists(dir string) bool { +func ConfigExists(m io.Medium, dir string) bool { configPath := ConfigPath(dir) absPath, err := filepath.Abs(configPath) if err != nil { return false } - return io.Local.IsFile(absPath) + return m.IsFile(absPath) } // GetRepository returns the repository from the config. @@ -286,7 +286,7 @@ func (c *Config) GetProjectName() string { } // WriteConfig writes the config to the .core/release.yaml file. -func WriteConfig(cfg *Config, dir string) error { +func WriteConfig(m io.Medium, cfg *Config, dir string) error { configPath := ConfigPath(dir) absPath, err := filepath.Abs(configPath) if err != nil { @@ -298,8 +298,8 @@ func WriteConfig(cfg *Config, dir string) error { return fmt.Errorf("release.WriteConfig: failed to marshal config: %w", err) } - // io.Local.Write creates parent directories automatically - if err := io.Local.Write(absPath, string(data)); err != nil { + // m.Write creates parent directories automatically + if err := m.Write(absPath, string(data)); err != nil { return fmt.Errorf("release.WriteConfig: failed to write config file: %w", err) } diff --git a/pkg/release/config_test.go b/pkg/release/config_test.go index 24fe1343..7af80e97 100644 --- a/pkg/release/config_test.go +++ b/pkg/release/config_test.go @@ -5,6 +5,7 @@ import ( "path/filepath" "testing" + "github.com/host-uk/core/pkg/io" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -53,7 +54,7 @@ changelog: ` dir := setupConfigTestDir(t, content) - cfg, err := LoadConfig(dir) + cfg, err := LoadConfig(io.Local, dir) require.NoError(t, err) require.NotNil(t, cfg) @@ -76,7 +77,7 @@ changelog: t.Run("returns defaults when config file missing", func(t *testing.T) { dir := t.TempDir() - cfg, err := LoadConfig(dir) + cfg, err := LoadConfig(io.Local, dir) require.NoError(t, err) require.NotNil(t, cfg) @@ -96,7 +97,7 @@ project: ` dir := setupConfigTestDir(t, content) - cfg, err := LoadConfig(dir) + cfg, err := LoadConfig(io.Local, dir) require.NoError(t, err) require.NotNil(t, cfg) @@ -113,7 +114,7 @@ project: t.Run("sets project directory on load", func(t *testing.T) { dir := setupConfigTestDir(t, "version: 1") - cfg, err := LoadConfig(dir) + cfg, err := LoadConfig(io.Local, dir) require.NoError(t, err) assert.Equal(t, dir, cfg.projectDir) }) @@ -128,7 +129,7 @@ project: ` dir := setupConfigTestDir(t, content) - cfg, err := LoadConfig(dir) + cfg, err := LoadConfig(io.Local, dir) assert.Error(t, err) assert.Nil(t, cfg) assert.Contains(t, err.Error(), "failed to parse config file") @@ -145,7 +146,7 @@ project: err = os.Mkdir(configPath, 0755) require.NoError(t, err) - cfg, err := LoadConfig(dir) + cfg, err := LoadConfig(io.Local, dir) assert.Error(t, err) assert.Nil(t, cfg) assert.Contains(t, err.Error(), "failed to read config file") @@ -204,17 +205,17 @@ func TestConfigPath_Good(t *testing.T) { func TestConfigExists_Good(t *testing.T) { t.Run("returns true when config exists", func(t *testing.T) { dir := setupConfigTestDir(t, "version: 1") - assert.True(t, ConfigExists(dir)) + assert.True(t, ConfigExists(io.Local, dir)) }) t.Run("returns false when config missing", func(t *testing.T) { dir := t.TempDir() - assert.False(t, ConfigExists(dir)) + assert.False(t, ConfigExists(io.Local, dir)) }) t.Run("returns false when .core dir missing", func(t *testing.T) { dir := t.TempDir() - assert.False(t, ConfigExists(dir)) + assert.False(t, ConfigExists(io.Local, dir)) }) } @@ -226,14 +227,14 @@ func TestWriteConfig_Good(t *testing.T) { cfg.Project.Name = "testapp" cfg.Project.Repository = "owner/testapp" - err := WriteConfig(cfg, dir) + err := WriteConfig(io.Local, cfg, dir) require.NoError(t, err) // Verify file exists - assert.True(t, ConfigExists(dir)) + assert.True(t, ConfigExists(io.Local, dir)) // Reload and verify - loaded, err := LoadConfig(dir) + loaded, err := LoadConfig(io.Local, dir) require.NoError(t, err) assert.Equal(t, "testapp", loaded.Project.Name) assert.Equal(t, "owner/testapp", loaded.Project.Repository) @@ -243,7 +244,7 @@ func TestWriteConfig_Good(t *testing.T) { dir := t.TempDir() cfg := DefaultConfig() - err := WriteConfig(cfg, dir) + err := WriteConfig(io.Local, cfg, dir) require.NoError(t, err) // Check directory was created @@ -320,7 +321,7 @@ func TestWriteConfig_Bad(t *testing.T) { defer func() { _ = os.Chmod(coreDir, 0755) }() cfg := DefaultConfig() - err = WriteConfig(cfg, dir) + err = WriteConfig(io.Local, cfg, dir) assert.Error(t, err) assert.Contains(t, err.Error(), "failed to write config file") }) @@ -328,7 +329,7 @@ func TestWriteConfig_Bad(t *testing.T) { t.Run("returns error when directory creation fails", func(t *testing.T) { // Use a path that doesn't exist and can't be created cfg := DefaultConfig() - err := WriteConfig(cfg, "/nonexistent/path/that/cannot/be/created") + err := WriteConfig(io.Local, cfg, "/nonexistent/path/that/cannot/be/created") assert.Error(t, err) }) } diff --git a/pkg/release/publishers/aur.go b/pkg/release/publishers/aur.go index 00ad86ca..0f9cd2c4 100644 --- a/pkg/release/publishers/aur.go +++ b/pkg/release/publishers/aur.go @@ -13,6 +13,7 @@ import ( "text/template" "github.com/host-uk/core/pkg/build" + "github.com/host-uk/core/pkg/io" ) //go:embed templates/aur/*.tmpl @@ -90,10 +91,10 @@ func (p *AURPublisher) Publish(ctx context.Context, release *Release, pubCfg Pub } if dryRun { - return p.dryRunPublish(data, cfg) + return p.dryRunPublish(release.FS, data, cfg) } - return p.executePublish(ctx, release.ProjectDir, data, cfg) + return p.executePublish(ctx, release.ProjectDir, data, cfg, release) } type aurTemplateData struct { @@ -131,7 +132,7 @@ func (p *AURPublisher) parseConfig(pubCfg PublisherConfig, relCfg ReleaseConfig) return cfg } -func (p *AURPublisher) dryRunPublish(data aurTemplateData, cfg AURConfig) error { +func (p *AURPublisher) dryRunPublish(m io.Medium, data aurTemplateData, cfg AURConfig) error { fmt.Println() fmt.Println("=== DRY RUN: AUR Publish ===") fmt.Println() @@ -141,7 +142,7 @@ func (p *AURPublisher) dryRunPublish(data aurTemplateData, cfg AURConfig) error fmt.Printf("Repository: %s\n", data.Repository) fmt.Println() - pkgbuild, err := p.renderTemplate("templates/aur/PKGBUILD.tmpl", data) + pkgbuild, err := p.renderTemplate(m, "templates/aur/PKGBUILD.tmpl", data) if err != nil { return fmt.Errorf("aur.dryRunPublish: %w", err) } @@ -151,7 +152,7 @@ func (p *AURPublisher) dryRunPublish(data aurTemplateData, cfg AURConfig) error fmt.Println("---") fmt.Println() - srcinfo, err := p.renderTemplate("templates/aur/.SRCINFO.tmpl", data) + srcinfo, err := p.renderTemplate(m, "templates/aur/.SRCINFO.tmpl", data) if err != nil { return fmt.Errorf("aur.dryRunPublish: %w", err) } @@ -168,13 +169,13 @@ func (p *AURPublisher) dryRunPublish(data aurTemplateData, cfg AURConfig) error return nil } -func (p *AURPublisher) executePublish(ctx context.Context, projectDir string, data aurTemplateData, cfg AURConfig) error { - pkgbuild, err := p.renderTemplate("templates/aur/PKGBUILD.tmpl", data) +func (p *AURPublisher) executePublish(ctx context.Context, projectDir string, data aurTemplateData, cfg AURConfig, release *Release) error { + pkgbuild, err := p.renderTemplate(release.FS, "templates/aur/PKGBUILD.tmpl", data) if err != nil { return fmt.Errorf("aur.Publish: failed to render PKGBUILD: %w", err) } - srcinfo, err := p.renderTemplate("templates/aur/.SRCINFO.tmpl", data) + srcinfo, err := p.renderTemplate(release.FS, "templates/aur/.SRCINFO.tmpl", data) if err != nil { return fmt.Errorf("aur.Publish: failed to render .SRCINFO: %w", err) } @@ -188,17 +189,17 @@ func (p *AURPublisher) executePublish(ctx context.Context, projectDir string, da output = filepath.Join(projectDir, output) } - if err := os.MkdirAll(output, 0755); err != nil { + if err := release.FS.EnsureDir(output); err != nil { return fmt.Errorf("aur.Publish: failed to create output directory: %w", err) } pkgbuildPath := filepath.Join(output, "PKGBUILD") - if err := os.WriteFile(pkgbuildPath, []byte(pkgbuild), 0644); err != nil { + if err := release.FS.Write(pkgbuildPath, pkgbuild); err != nil { return fmt.Errorf("aur.Publish: failed to write PKGBUILD: %w", err) } srcinfoPath := filepath.Join(output, ".SRCINFO") - if err := os.WriteFile(srcinfoPath, []byte(srcinfo), 0644); err != nil { + if err := release.FS.Write(srcinfoPath, srcinfo); err != nil { return fmt.Errorf("aur.Publish: failed to write .SRCINFO: %w", err) } fmt.Printf("Wrote AUR files: %s\n", output) @@ -274,10 +275,25 @@ func (p *AURPublisher) pushToAUR(ctx context.Context, data aurTemplateData, pkgb return nil } -func (p *AURPublisher) renderTemplate(name string, data aurTemplateData) (string, error) { - content, err := aurTemplates.ReadFile(name) - if err != nil { - return "", fmt.Errorf("failed to read template %s: %w", name, err) +func (p *AURPublisher) renderTemplate(m io.Medium, name string, data aurTemplateData) (string, error) { + var content []byte + var err error + + // Try custom template from medium + customPath := filepath.Join(".core", name) + if m != nil && m.IsFile(customPath) { + customContent, err := m.Read(customPath) + if err == nil { + content = []byte(customContent) + } + } + + // Fallback to embedded template + if content == nil { + content, err = aurTemplates.ReadFile(name) + if err != nil { + return "", fmt.Errorf("failed to read template %s: %w", name, err) + } } tmpl, err := template.New(filepath.Base(name)).Parse(string(content)) diff --git a/pkg/release/publishers/aur_test.go b/pkg/release/publishers/aur_test.go index a49b68e1..3b0e6231 100644 --- a/pkg/release/publishers/aur_test.go +++ b/pkg/release/publishers/aur_test.go @@ -6,6 +6,7 @@ import ( "os" "testing" + "github.com/host-uk/core/pkg/io" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -97,7 +98,7 @@ func TestAURPublisher_RenderTemplate_Good(t *testing.T) { }, } - result, err := p.renderTemplate("templates/aur/PKGBUILD.tmpl", data) + result, err := p.renderTemplate(io.Local, "templates/aur/PKGBUILD.tmpl", data) require.NoError(t, err) assert.Contains(t, result, "# Maintainer: John Doe ") @@ -125,7 +126,7 @@ func TestAURPublisher_RenderTemplate_Good(t *testing.T) { }, } - result, err := p.renderTemplate("templates/aur/.SRCINFO.tmpl", data) + result, err := p.renderTemplate(io.Local, "templates/aur/.SRCINFO.tmpl", data) require.NoError(t, err) assert.Contains(t, result, "pkgbase = myapp-bin") @@ -144,7 +145,7 @@ func TestAURPublisher_RenderTemplate_Bad(t *testing.T) { t.Run("returns error for non-existent template", func(t *testing.T) { data := aurTemplateData{} - _, err := p.renderTemplate("templates/aur/nonexistent.tmpl", data) + _, err := p.renderTemplate(io.Local, "templates/aur/nonexistent.tmpl", data) assert.Error(t, err) assert.Contains(t, err.Error(), "failed to read template") }) @@ -170,7 +171,7 @@ func TestAURPublisher_DryRunPublish_Good(t *testing.T) { Maintainer: "John Doe ", } - err := p.dryRunPublish(data, cfg) + err := p.dryRunPublish(io.Local, data, cfg) _ = w.Close() var buf bytes.Buffer @@ -199,6 +200,7 @@ func TestAURPublisher_Publish_Bad(t *testing.T) { release := &Release{ Version: "v1.0.0", ProjectDir: "/project", + FS: io.Local, } pubCfg := PublisherConfig{Type: "aur"} relCfg := &mockReleaseConfig{repository: "owner/repo"} diff --git a/pkg/release/publishers/chocolatey.go b/pkg/release/publishers/chocolatey.go index 9c58d2d1..93b12160 100644 --- a/pkg/release/publishers/chocolatey.go +++ b/pkg/release/publishers/chocolatey.go @@ -14,6 +14,7 @@ import ( "github.com/host-uk/core/pkg/build" "github.com/host-uk/core/pkg/i18n" + "github.com/host-uk/core/pkg/io" ) //go:embed templates/chocolatey/*.tmpl templates/chocolatey/tools/*.tmpl @@ -92,10 +93,10 @@ func (p *ChocolateyPublisher) Publish(ctx context.Context, release *Release, pub } if dryRun { - return p.dryRunPublish(data, cfg) + return p.dryRunPublish(release.FS, data, cfg) } - return p.executePublish(ctx, release.ProjectDir, data, cfg) + return p.executePublish(ctx, release.ProjectDir, data, cfg, release) } type chocolateyTemplateData struct { @@ -137,7 +138,7 @@ func (p *ChocolateyPublisher) parseConfig(pubCfg PublisherConfig, relCfg Release return cfg } -func (p *ChocolateyPublisher) dryRunPublish(data chocolateyTemplateData, cfg ChocolateyConfig) error { +func (p *ChocolateyPublisher) dryRunPublish(m io.Medium, data chocolateyTemplateData, cfg ChocolateyConfig) error { fmt.Println() fmt.Println("=== DRY RUN: Chocolatey Publish ===") fmt.Println() @@ -147,7 +148,7 @@ func (p *ChocolateyPublisher) dryRunPublish(data chocolateyTemplateData, cfg Cho fmt.Printf("Repository: %s\n", data.Repository) fmt.Println() - nuspec, err := p.renderTemplate("templates/chocolatey/package.nuspec.tmpl", data) + nuspec, err := p.renderTemplate(m, "templates/chocolatey/package.nuspec.tmpl", data) if err != nil { return fmt.Errorf("chocolatey.dryRunPublish: %w", err) } @@ -157,7 +158,7 @@ func (p *ChocolateyPublisher) dryRunPublish(data chocolateyTemplateData, cfg Cho fmt.Println("---") fmt.Println() - install, err := p.renderTemplate("templates/chocolatey/tools/chocolateyinstall.ps1.tmpl", data) + install, err := p.renderTemplate(m, "templates/chocolatey/tools/chocolateyinstall.ps1.tmpl", data) if err != nil { return fmt.Errorf("chocolatey.dryRunPublish: %w", err) } @@ -178,13 +179,13 @@ func (p *ChocolateyPublisher) dryRunPublish(data chocolateyTemplateData, cfg Cho return nil } -func (p *ChocolateyPublisher) executePublish(ctx context.Context, projectDir string, data chocolateyTemplateData, cfg ChocolateyConfig) error { - nuspec, err := p.renderTemplate("templates/chocolatey/package.nuspec.tmpl", data) +func (p *ChocolateyPublisher) executePublish(ctx context.Context, projectDir string, data chocolateyTemplateData, cfg ChocolateyConfig, release *Release) error { + nuspec, err := p.renderTemplate(release.FS, "templates/chocolatey/package.nuspec.tmpl", data) if err != nil { return fmt.Errorf("chocolatey.Publish: failed to render nuspec: %w", err) } - install, err := p.renderTemplate("templates/chocolatey/tools/chocolateyinstall.ps1.tmpl", data) + install, err := p.renderTemplate(release.FS, "templates/chocolatey/tools/chocolateyinstall.ps1.tmpl", data) if err != nil { return fmt.Errorf("chocolatey.Publish: failed to render install script: %w", err) } @@ -199,18 +200,18 @@ func (p *ChocolateyPublisher) executePublish(ctx context.Context, projectDir str } toolsDir := filepath.Join(output, "tools") - if err := os.MkdirAll(toolsDir, 0755); err != nil { + if err := release.FS.EnsureDir(toolsDir); err != nil { return fmt.Errorf("chocolatey.Publish: failed to create output directory: %w", err) } // Write files nuspecPath := filepath.Join(output, fmt.Sprintf("%s.nuspec", data.PackageName)) - if err := os.WriteFile(nuspecPath, []byte(nuspec), 0644); err != nil { + if err := release.FS.Write(nuspecPath, nuspec); err != nil { return fmt.Errorf("chocolatey.Publish: failed to write nuspec: %w", err) } installPath := filepath.Join(toolsDir, "chocolateyinstall.ps1") - if err := os.WriteFile(installPath, []byte(install), 0644); err != nil { + if err := release.FS.Write(installPath, install); err != nil { return fmt.Errorf("chocolatey.Publish: failed to write install script: %w", err) } @@ -255,10 +256,25 @@ func (p *ChocolateyPublisher) pushToChocolatey(ctx context.Context, packageDir s return nil } -func (p *ChocolateyPublisher) renderTemplate(name string, data chocolateyTemplateData) (string, error) { - content, err := chocolateyTemplates.ReadFile(name) - if err != nil { - return "", fmt.Errorf("failed to read template %s: %w", name, err) +func (p *ChocolateyPublisher) renderTemplate(m io.Medium, name string, data chocolateyTemplateData) (string, error) { + var content []byte + var err error + + // Try custom template from medium + customPath := filepath.Join(".core", name) + if m != nil && m.IsFile(customPath) { + customContent, err := m.Read(customPath) + if err == nil { + content = []byte(customContent) + } + } + + // Fallback to embedded template + if content == nil { + content, err = chocolateyTemplates.ReadFile(name) + if err != nil { + return "", fmt.Errorf("failed to read template %s: %w", name, err) + } } tmpl, err := template.New(filepath.Base(name)).Parse(string(content)) diff --git a/pkg/release/publishers/chocolatey_test.go b/pkg/release/publishers/chocolatey_test.go index 3da669b1..df41aba4 100644 --- a/pkg/release/publishers/chocolatey_test.go +++ b/pkg/release/publishers/chocolatey_test.go @@ -6,6 +6,8 @@ import ( "os" "testing" + "github.com/host-uk/core/pkg/io" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -122,7 +124,7 @@ func TestChocolateyPublisher_RenderTemplate_Good(t *testing.T) { Checksums: ChecksumMap{}, } - result, err := p.renderTemplate("templates/chocolatey/package.nuspec.tmpl", data) + result, err := p.renderTemplate(io.Local, "templates/chocolatey/package.nuspec.tmpl", data) require.NoError(t, err) assert.Contains(t, result, `myapp`) @@ -146,7 +148,7 @@ func TestChocolateyPublisher_RenderTemplate_Good(t *testing.T) { }, } - result, err := p.renderTemplate("templates/chocolatey/tools/chocolateyinstall.ps1.tmpl", data) + result, err := p.renderTemplate(io.Local, "templates/chocolatey/tools/chocolateyinstall.ps1.tmpl", data) require.NoError(t, err) assert.Contains(t, result, "$ErrorActionPreference = 'Stop'") @@ -163,7 +165,7 @@ func TestChocolateyPublisher_RenderTemplate_Bad(t *testing.T) { t.Run("returns error for non-existent template", func(t *testing.T) { data := chocolateyTemplateData{} - _, err := p.renderTemplate("templates/chocolatey/nonexistent.tmpl", data) + _, err := p.renderTemplate(io.Local, "templates/chocolatey/nonexistent.tmpl", data) assert.Error(t, err) assert.Contains(t, err.Error(), "failed to read template") }) @@ -190,7 +192,7 @@ func TestChocolateyPublisher_DryRunPublish_Good(t *testing.T) { Push: false, } - err := p.dryRunPublish(data, cfg) + err := p.dryRunPublish(io.Local, data, cfg) _ = w.Close() var buf bytes.Buffer @@ -228,7 +230,7 @@ func TestChocolateyPublisher_DryRunPublish_Good(t *testing.T) { Push: true, } - err := p.dryRunPublish(data, cfg) + err := p.dryRunPublish(io.Local, data, cfg) _ = w.Close() var buf bytes.Buffer diff --git a/pkg/release/publishers/docker.go b/pkg/release/publishers/docker.go index 7d342ab3..981d4420 100644 --- a/pkg/release/publishers/docker.go +++ b/pkg/release/publishers/docker.go @@ -50,7 +50,7 @@ func (p *DockerPublisher) Publish(ctx context.Context, release *Release, pubCfg dockerCfg := p.parseConfig(pubCfg, relCfg, release.ProjectDir) // Validate Dockerfile exists - if _, err := os.Stat(dockerCfg.Dockerfile); err != nil { + if !release.FS.Exists(dockerCfg.Dockerfile) { return fmt.Errorf("docker.Publish: Dockerfile not found: %s", dockerCfg.Dockerfile) } diff --git a/pkg/release/publishers/docker_test.go b/pkg/release/publishers/docker_test.go index a36a5517..9673a274 100644 --- a/pkg/release/publishers/docker_test.go +++ b/pkg/release/publishers/docker_test.go @@ -7,6 +7,7 @@ import ( "path/filepath" "testing" + "github.com/host-uk/core/pkg/io" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -238,6 +239,7 @@ func TestDockerPublisher_Publish_Bad(t *testing.T) { release := &Release{ Version: "v1.0.0", ProjectDir: "/nonexistent", + FS: io.Local, } pubCfg := PublisherConfig{ Type: "docker", @@ -282,6 +284,7 @@ func TestDockerPublisher_DryRunPublish_Good(t *testing.T) { release := &Release{ Version: "v1.0.0", ProjectDir: "/project", + FS: io.Local, } cfg := DockerConfig{ Registry: "ghcr.io", @@ -324,6 +327,7 @@ func TestDockerPublisher_DryRunPublish_Good(t *testing.T) { release := &Release{ Version: "v1.0.0", ProjectDir: "/project", + FS: io.Local, } cfg := DockerConfig{ Registry: "docker.io", @@ -360,6 +364,7 @@ func TestDockerPublisher_DryRunPublish_Good(t *testing.T) { release := &Release{ Version: "v2.0.0", ProjectDir: "/project", + FS: io.Local, } cfg := DockerConfig{ Registry: "ghcr.io", @@ -583,6 +588,7 @@ func TestDockerPublisher_Publish_DryRun_Good(t *testing.T) { release := &Release{ Version: "v1.0.0", ProjectDir: tmpDir, + FS: io.Local, } pubCfg := PublisherConfig{Type: "docker"} relCfg := &mockReleaseConfig{repository: "owner/repo"} @@ -620,6 +626,7 @@ func TestDockerPublisher_Publish_DryRun_Good(t *testing.T) { release := &Release{ Version: "v1.0.0", ProjectDir: tmpDir, + FS: io.Local, } pubCfg := PublisherConfig{ Type: "docker", @@ -653,6 +660,7 @@ func TestDockerPublisher_Publish_Validation_Bad(t *testing.T) { release := &Release{ Version: "v1.0.0", ProjectDir: "/nonexistent/path", + FS: io.Local, } pubCfg := PublisherConfig{Type: "docker"} relCfg := &mockReleaseConfig{repository: "owner/repo"} @@ -670,6 +678,7 @@ func TestDockerPublisher_Publish_Validation_Bad(t *testing.T) { release := &Release{ Version: "v1.0.0", ProjectDir: "/tmp", + FS: io.Local, } pubCfg := PublisherConfig{Type: "docker"} relCfg := &mockReleaseConfig{repository: "owner/repo"} @@ -715,6 +724,7 @@ func TestDockerPublisher_Publish_WithCLI_Good(t *testing.T) { release := &Release{ Version: "v1.0.0", ProjectDir: tmpDir, + FS: io.Local, } pubCfg := PublisherConfig{ Type: "docker", @@ -758,6 +768,7 @@ func TestDockerPublisher_Publish_WithCLI_Good(t *testing.T) { release := &Release{ Version: "v1.0.0", ProjectDir: tmpDir, + FS: io.Local, } pubCfg := PublisherConfig{ Type: "docker", @@ -787,6 +798,7 @@ func TestDockerPublisher_Publish_WithCLI_Good(t *testing.T) { release := &Release{ Version: "v1.0.0", ProjectDir: tmpDir, + FS: io.Local, } pubCfg := PublisherConfig{Type: "docker"} relCfg := &mockReleaseConfig{repository: "owner/repo"} diff --git a/pkg/release/publishers/github_test.go b/pkg/release/publishers/github_test.go index 78af460f..7d89d053 100644 --- a/pkg/release/publishers/github_test.go +++ b/pkg/release/publishers/github_test.go @@ -9,6 +9,7 @@ import ( "testing" "github.com/host-uk/core/pkg/build" + "github.com/host-uk/core/pkg/io" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -90,7 +91,7 @@ func TestGitHubPublisher_Name_Good(t *testing.T) { func TestNewRelease_Good(t *testing.T) { t.Run("creates release struct", func(t *testing.T) { - r := NewRelease("v1.0.0", nil, "changelog", "/project") + r := NewRelease("v1.0.0", nil, "changelog", "/project", io.Local) assert.Equal(t, "v1.0.0", r.Version) assert.Equal(t, "changelog", r.Changelog) assert.Equal(t, "/project", r.ProjectDir) @@ -122,6 +123,7 @@ func TestBuildCreateArgs_Good(t *testing.T) { release := &Release{ Version: "v1.0.0", Changelog: "## v1.0.0\n\nChanges", + FS: io.Local, } cfg := PublisherConfig{ Type: "github", @@ -141,6 +143,7 @@ func TestBuildCreateArgs_Good(t *testing.T) { t.Run("with draft flag", func(t *testing.T) { release := &Release{ Version: "v1.0.0", + FS: io.Local, } cfg := PublisherConfig{ Type: "github", @@ -155,6 +158,7 @@ func TestBuildCreateArgs_Good(t *testing.T) { t.Run("with prerelease flag", func(t *testing.T) { release := &Release{ Version: "v1.0.0", + FS: io.Local, } cfg := PublisherConfig{ Type: "github", @@ -170,6 +174,7 @@ func TestBuildCreateArgs_Good(t *testing.T) { release := &Release{ Version: "v1.0.0", Changelog: "", + FS: io.Local, } cfg := PublisherConfig{ Type: "github", @@ -183,6 +188,7 @@ func TestBuildCreateArgs_Good(t *testing.T) { t.Run("with draft and prerelease flags", func(t *testing.T) { release := &Release{ Version: "v1.0.0-alpha", + FS: io.Local, } cfg := PublisherConfig{ Type: "github", @@ -200,6 +206,7 @@ func TestBuildCreateArgs_Good(t *testing.T) { release := &Release{ Version: "v2.0.0", Changelog: "Some changes", + FS: io.Local, } cfg := PublisherConfig{ Type: "github", @@ -226,6 +233,7 @@ func TestGitHubPublisher_DryRunPublish_Good(t *testing.T) { Version: "v1.0.0", Changelog: "## Changes\n\n- Feature A\n- Bug fix B", ProjectDir: "/project", + FS: io.Local, } cfg := PublisherConfig{ Type: "github", @@ -264,6 +272,7 @@ func TestGitHubPublisher_DryRunPublish_Good(t *testing.T) { Version: "v1.0.0", Changelog: "Changes", ProjectDir: "/project", + FS: io.Local, Artifacts: []build.Artifact{ {Path: "/dist/myapp-darwin-amd64.tar.gz"}, {Path: "/dist/myapp-linux-amd64.tar.gz"}, @@ -295,6 +304,7 @@ func TestGitHubPublisher_DryRunPublish_Good(t *testing.T) { Version: "v1.0.0-beta", Changelog: "Beta release", ProjectDir: "/project", + FS: io.Local, } cfg := PublisherConfig{ Type: "github", @@ -331,6 +341,7 @@ func TestGitHubPublisher_Publish_Good(t *testing.T) { Version: "v1.0.0", Changelog: "Changes", ProjectDir: "/tmp", + FS: io.Local, } pubCfg := PublisherConfig{Type: "github"} relCfg := &mockReleaseConfig{repository: "custom/repo"} @@ -363,6 +374,7 @@ func TestGitHubPublisher_Publish_Bad(t *testing.T) { Version: "v1.0.0", Changelog: "Changes", ProjectDir: "/nonexistent", + FS: io.Local, } pubCfg := PublisherConfig{Type: "github"} relCfg := &mockReleaseConfig{repository: "owner/repo"} @@ -383,6 +395,7 @@ func TestGitHubPublisher_Publish_Bad(t *testing.T) { Version: "v1.0.0", Changelog: "Changes", ProjectDir: tmpDir, + FS: io.Local, } pubCfg := PublisherConfig{Type: "github"} relCfg := &mockReleaseConfig{repository: ""} // Empty repository @@ -504,6 +517,7 @@ func TestGitHubPublisher_ExecutePublish_Good(t *testing.T) { Version: "v999.999.999-test-nonexistent", Changelog: "Test changelog", ProjectDir: "/tmp", + FS: io.Local, Artifacts: []build.Artifact{ {Path: "/tmp/nonexistent-artifact.tar.gz"}, }, diff --git a/pkg/release/publishers/homebrew.go b/pkg/release/publishers/homebrew.go index 00b9abb0..10fc3d7d 100644 --- a/pkg/release/publishers/homebrew.go +++ b/pkg/release/publishers/homebrew.go @@ -13,6 +13,7 @@ import ( "text/template" "github.com/host-uk/core/pkg/build" + "github.com/host-uk/core/pkg/io" ) //go:embed templates/homebrew/*.tmpl @@ -104,10 +105,10 @@ func (p *HomebrewPublisher) Publish(ctx context.Context, release *Release, pubCf } if dryRun { - return p.dryRunPublish(data, cfg) + return p.dryRunPublish(release.FS, data, cfg) } - return p.executePublish(ctx, release.ProjectDir, data, cfg) + return p.executePublish(ctx, release.ProjectDir, data, cfg, release) } // homebrewTemplateData holds data for Homebrew templates. @@ -160,7 +161,7 @@ func (p *HomebrewPublisher) parseConfig(pubCfg PublisherConfig, relCfg ReleaseCo } // dryRunPublish shows what would be done. -func (p *HomebrewPublisher) dryRunPublish(data homebrewTemplateData, cfg HomebrewConfig) error { +func (p *HomebrewPublisher) dryRunPublish(m io.Medium, data homebrewTemplateData, cfg HomebrewConfig) error { fmt.Println() fmt.Println("=== DRY RUN: Homebrew Publish ===") fmt.Println() @@ -171,7 +172,7 @@ func (p *HomebrewPublisher) dryRunPublish(data homebrewTemplateData, cfg Homebre fmt.Println() // Generate and show formula - formula, err := p.renderTemplate("templates/homebrew/formula.rb.tmpl", data) + formula, err := p.renderTemplate(m, "templates/homebrew/formula.rb.tmpl", data) if err != nil { return fmt.Errorf("homebrew.dryRunPublish: %w", err) } @@ -198,9 +199,9 @@ func (p *HomebrewPublisher) dryRunPublish(data homebrewTemplateData, cfg Homebre } // executePublish creates the formula and commits to tap. -func (p *HomebrewPublisher) executePublish(ctx context.Context, projectDir string, data homebrewTemplateData, cfg HomebrewConfig) error { +func (p *HomebrewPublisher) executePublish(ctx context.Context, projectDir string, data homebrewTemplateData, cfg HomebrewConfig, release *Release) error { // Generate formula - formula, err := p.renderTemplate("templates/homebrew/formula.rb.tmpl", data) + formula, err := p.renderTemplate(release.FS, "templates/homebrew/formula.rb.tmpl", data) if err != nil { return fmt.Errorf("homebrew.Publish: failed to render formula: %w", err) } @@ -214,12 +215,12 @@ func (p *HomebrewPublisher) executePublish(ctx context.Context, projectDir strin output = filepath.Join(projectDir, output) } - if err := os.MkdirAll(output, 0755); err != nil { + if err := release.FS.EnsureDir(output); err != nil { return fmt.Errorf("homebrew.Publish: failed to create output directory: %w", err) } formulaPath := filepath.Join(output, fmt.Sprintf("%s.rb", strings.ToLower(data.FormulaClass))) - if err := os.WriteFile(formulaPath, []byte(formula), 0644); err != nil { + if err := release.FS.Write(formulaPath, formula); err != nil { return fmt.Errorf("homebrew.Publish: failed to write formula: %w", err) } fmt.Printf("Wrote Homebrew formula for official PR: %s\n", formulaPath) @@ -295,10 +296,25 @@ func (p *HomebrewPublisher) commitToTap(ctx context.Context, tap string, data ho } // renderTemplate renders an embedded template with the given data. -func (p *HomebrewPublisher) renderTemplate(name string, data homebrewTemplateData) (string, error) { - content, err := homebrewTemplates.ReadFile(name) - if err != nil { - return "", fmt.Errorf("failed to read template %s: %w", name, err) +func (p *HomebrewPublisher) renderTemplate(m io.Medium, name string, data homebrewTemplateData) (string, error) { + var content []byte + var err error + + // Try custom template from medium + customPath := filepath.Join(".core", name) + if m != nil && m.IsFile(customPath) { + customContent, err := m.Read(customPath) + if err == nil { + content = []byte(customContent) + } + } + + // Fallback to embedded template + if content == nil { + content, err = homebrewTemplates.ReadFile(name) + if err != nil { + return "", fmt.Errorf("failed to read template %s: %w", name, err) + } } tmpl, err := template.New(filepath.Base(name)).Parse(string(content)) diff --git a/pkg/release/publishers/homebrew_test.go b/pkg/release/publishers/homebrew_test.go index d9e0c112..e05f24e2 100644 --- a/pkg/release/publishers/homebrew_test.go +++ b/pkg/release/publishers/homebrew_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/host-uk/core/pkg/build" + "github.com/host-uk/core/pkg/io" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -185,7 +186,7 @@ func TestHomebrewPublisher_RenderTemplate_Good(t *testing.T) { }, } - result, err := p.renderTemplate("templates/homebrew/formula.rb.tmpl", data) + result, err := p.renderTemplate(io.Local, "templates/homebrew/formula.rb.tmpl", data) require.NoError(t, err) assert.Contains(t, result, "class MyApp < Formula") @@ -206,7 +207,7 @@ func TestHomebrewPublisher_RenderTemplate_Bad(t *testing.T) { t.Run("returns error for non-existent template", func(t *testing.T) { data := homebrewTemplateData{} - _, err := p.renderTemplate("templates/homebrew/nonexistent.tmpl", data) + _, err := p.renderTemplate(io.Local, "templates/homebrew/nonexistent.tmpl", data) assert.Error(t, err) assert.Contains(t, err.Error(), "failed to read template") }) @@ -234,7 +235,7 @@ func TestHomebrewPublisher_DryRunPublish_Good(t *testing.T) { Tap: "owner/homebrew-tap", } - err := p.dryRunPublish(data, cfg) + err := p.dryRunPublish(io.Local, data, cfg) _ = w.Close() var buf bytes.Buffer @@ -271,7 +272,7 @@ func TestHomebrewPublisher_DryRunPublish_Good(t *testing.T) { }, } - err := p.dryRunPublish(data, cfg) + err := p.dryRunPublish(io.Local, data, cfg) _ = w.Close() var buf bytes.Buffer @@ -300,7 +301,7 @@ func TestHomebrewPublisher_DryRunPublish_Good(t *testing.T) { }, } - err := p.dryRunPublish(data, cfg) + err := p.dryRunPublish(io.Local, data, cfg) _ = w.Close() var buf bytes.Buffer @@ -320,6 +321,7 @@ func TestHomebrewPublisher_Publish_Bad(t *testing.T) { release := &Release{ Version: "v1.0.0", ProjectDir: "/project", + FS: io.Local, } pubCfg := PublisherConfig{Type: "homebrew"} relCfg := &mockReleaseConfig{repository: "owner/repo"} diff --git a/pkg/release/publishers/linuxkit.go b/pkg/release/publishers/linuxkit.go index 2a5ca828..4905575d 100644 --- a/pkg/release/publishers/linuxkit.go +++ b/pkg/release/publishers/linuxkit.go @@ -47,7 +47,7 @@ func (p *LinuxKitPublisher) Publish(ctx context.Context, release *Release, pubCf lkCfg := p.parseConfig(pubCfg, release.ProjectDir) // Validate config file exists - if _, err := os.Stat(lkCfg.Config); err != nil { + if !release.FS.Exists(lkCfg.Config) { return fmt.Errorf("linuxkit.Publish: config file not found: %s", lkCfg.Config) } @@ -169,7 +169,7 @@ func (p *LinuxKitPublisher) executePublish(ctx context.Context, release *Release outputDir := filepath.Join(release.ProjectDir, "dist", "linuxkit") // Create output directory - if err := os.MkdirAll(outputDir, 0755); err != nil { + if err := release.FS.EnsureDir(outputDir); err != nil { return fmt.Errorf("linuxkit.Publish: failed to create output directory: %w", err) } @@ -207,7 +207,7 @@ func (p *LinuxKitPublisher) executePublish(ctx context.Context, release *Release // Upload artifacts to GitHub release for _, artifactPath := range artifacts { - if _, err := os.Stat(artifactPath); err != nil { + if !release.FS.Exists(artifactPath) { return fmt.Errorf("linuxkit.Publish: artifact not found after build: %s", artifactPath) } diff --git a/pkg/release/publishers/linuxkit_test.go b/pkg/release/publishers/linuxkit_test.go index 361d1fa3..7def1da4 100644 --- a/pkg/release/publishers/linuxkit_test.go +++ b/pkg/release/publishers/linuxkit_test.go @@ -8,6 +8,7 @@ import ( "path/filepath" "testing" + "github.com/host-uk/core/pkg/io" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -192,6 +193,7 @@ func TestLinuxKitPublisher_Publish_Bad(t *testing.T) { release := &Release{ Version: "v1.0.0", ProjectDir: "/nonexistent", + FS: io.Local, } pubCfg := PublisherConfig{ Type: "linuxkit", @@ -214,6 +216,7 @@ func TestLinuxKitPublisher_Publish_Bad(t *testing.T) { release := &Release{ Version: "v1.0.0", ProjectDir: "/tmp", + FS: io.Local, } pubCfg := PublisherConfig{Type: "linuxkit"} relCfg := &mockReleaseConfig{repository: "owner/repo"} @@ -241,6 +244,7 @@ func TestLinuxKitPublisher_Publish_Bad(t *testing.T) { release := &Release{ Version: "v1.0.0", ProjectDir: tmpDir, + FS: io.Local, } pubCfg := PublisherConfig{ Type: "linuxkit", @@ -296,6 +300,7 @@ func TestLinuxKitPublisher_Publish_WithCLI_Good(t *testing.T) { release := &Release{ Version: "v1.0.0", ProjectDir: tmpDir, + FS: io.Local, } pubCfg := PublisherConfig{Type: "linuxkit"} relCfg := &mockReleaseConfig{repository: "owner/repo"} @@ -320,6 +325,7 @@ func TestLinuxKitPublisher_Publish_WithCLI_Good(t *testing.T) { release := &Release{ Version: "v1.0.0", ProjectDir: tmpDir, + FS: io.Local, } pubCfg := PublisherConfig{Type: "linuxkit"} relCfg := &mockReleaseConfig{repository: "owner/repo"} @@ -349,6 +355,7 @@ func TestLinuxKitPublisher_Publish_WithCLI_Good(t *testing.T) { release := &Release{ Version: "v1.0.0", ProjectDir: tmpDir, + FS: io.Local, } pubCfg := PublisherConfig{Type: "linuxkit"} relCfg := &mockReleaseConfig{repository: "custom-owner/custom-repo"} @@ -395,6 +402,7 @@ func TestLinuxKitPublisher_Publish_WithCLI_Good(t *testing.T) { release := &Release{ Version: "v1.0.0", ProjectDir: tmpDir, + FS: io.Local, } pubCfg := PublisherConfig{Type: "linuxkit"} relCfg := &mockReleaseConfig{repository: ""} // Empty to trigger detection @@ -490,6 +498,7 @@ func TestLinuxKitPublisher_DryRunPublish_Good(t *testing.T) { release := &Release{ Version: "v1.0.0", ProjectDir: "/project", + FS: io.Local, } cfg := LinuxKitConfig{ Config: "/project/.core/linuxkit/server.yml", @@ -531,6 +540,7 @@ func TestLinuxKitPublisher_DryRunPublish_Good(t *testing.T) { release := &Release{ Version: "v1.0.0", ProjectDir: "/project", + FS: io.Local, } cfg := LinuxKitConfig{ Config: "/config.yml", @@ -560,6 +570,7 @@ func TestLinuxKitPublisher_DryRunPublish_Good(t *testing.T) { release := &Release{ Version: "v2.0.0", ProjectDir: "/project", + FS: io.Local, } cfg := LinuxKitConfig{ Config: "/config.yml", @@ -823,6 +834,7 @@ func TestLinuxKitPublisher_Publish_DryRun_Good(t *testing.T) { release := &Release{ Version: "v1.0.0", ProjectDir: tmpDir, + FS: io.Local, } pubCfg := PublisherConfig{Type: "linuxkit"} relCfg := &mockReleaseConfig{repository: "owner/repo"} @@ -855,6 +867,7 @@ func TestLinuxKitPublisher_Publish_DryRun_Good(t *testing.T) { release := &Release{ Version: "v1.0.0", ProjectDir: tmpDir, + FS: io.Local, } pubCfg := PublisherConfig{ Type: "linuxkit", @@ -892,6 +905,7 @@ func TestLinuxKitPublisher_Publish_DryRun_Good(t *testing.T) { release := &Release{ Version: "v2.0.0", ProjectDir: tmpDir, + FS: io.Local, } pubCfg := PublisherConfig{ Type: "linuxkit", diff --git a/pkg/release/publishers/npm.go b/pkg/release/publishers/npm.go index 314b8e02..85df9283 100644 --- a/pkg/release/publishers/npm.go +++ b/pkg/release/publishers/npm.go @@ -11,6 +11,8 @@ import ( "path/filepath" "strings" "text/template" + + "github.com/host-uk/core/pkg/io" ) //go:embed templates/npm/*.tmpl @@ -88,10 +90,10 @@ func (p *NpmPublisher) Publish(ctx context.Context, release *Release, pubCfg Pub } if dryRun { - return p.dryRunPublish(data, &npmCfg) + return p.dryRunPublish(release.FS, data, &npmCfg) } - return p.executePublish(ctx, data, &npmCfg) + return p.executePublish(ctx, release.FS, data, &npmCfg) } // parseConfig extracts npm-specific configuration from the publisher config. @@ -127,7 +129,7 @@ type npmTemplateData struct { } // dryRunPublish shows what would be done without actually publishing. -func (p *NpmPublisher) dryRunPublish(data npmTemplateData, cfg *NpmConfig) error { +func (p *NpmPublisher) dryRunPublish(m io.Medium, data npmTemplateData, cfg *NpmConfig) error { fmt.Println() fmt.Println("=== DRY RUN: npm Publish ===") fmt.Println() @@ -139,7 +141,7 @@ func (p *NpmPublisher) dryRunPublish(data npmTemplateData, cfg *NpmConfig) error fmt.Println() // Generate and show package.json - pkgJSON, err := p.renderTemplate("templates/npm/package.json.tmpl", data) + pkgJSON, err := p.renderTemplate(m, "templates/npm/package.json.tmpl", data) if err != nil { return fmt.Errorf("npm.dryRunPublish: %w", err) } @@ -157,7 +159,7 @@ func (p *NpmPublisher) dryRunPublish(data npmTemplateData, cfg *NpmConfig) error } // executePublish actually creates and publishes the npm package. -func (p *NpmPublisher) executePublish(ctx context.Context, data npmTemplateData, cfg *NpmConfig) error { +func (p *NpmPublisher) executePublish(ctx context.Context, m io.Medium, data npmTemplateData, cfg *NpmConfig) error { // Check for NPM_TOKEN if os.Getenv("NPM_TOKEN") == "" { return fmt.Errorf("npm.Publish: NPM_TOKEN environment variable is required") @@ -177,7 +179,7 @@ func (p *NpmPublisher) executePublish(ctx context.Context, data npmTemplateData, } // Generate package.json - pkgJSON, err := p.renderTemplate("templates/npm/package.json.tmpl", data) + pkgJSON, err := p.renderTemplate(m, "templates/npm/package.json.tmpl", data) if err != nil { return fmt.Errorf("npm.Publish: failed to render package.json: %w", err) } @@ -186,7 +188,7 @@ func (p *NpmPublisher) executePublish(ctx context.Context, data npmTemplateData, } // Generate install.js - installJS, err := p.renderTemplate("templates/npm/install.js.tmpl", data) + installJS, err := p.renderTemplate(m, "templates/npm/install.js.tmpl", data) if err != nil { return fmt.Errorf("npm.Publish: failed to render install.js: %w", err) } @@ -195,7 +197,7 @@ func (p *NpmPublisher) executePublish(ctx context.Context, data npmTemplateData, } // Generate run.js - runJS, err := p.renderTemplate("templates/npm/run.js.tmpl", data) + runJS, err := p.renderTemplate(m, "templates/npm/run.js.tmpl", data) if err != nil { return fmt.Errorf("npm.Publish: failed to render run.js: %w", err) } @@ -228,10 +230,25 @@ func (p *NpmPublisher) executePublish(ctx context.Context, data npmTemplateData, } // renderTemplate renders an embedded template with the given data. -func (p *NpmPublisher) renderTemplate(name string, data npmTemplateData) (string, error) { - content, err := npmTemplates.ReadFile(name) - if err != nil { - return "", fmt.Errorf("failed to read template %s: %w", name, err) +func (p *NpmPublisher) renderTemplate(m io.Medium, name string, data npmTemplateData) (string, error) { + var content []byte + var err error + + // Try custom template from medium + customPath := filepath.Join(".core", name) + if m != nil && m.IsFile(customPath) { + customContent, err := m.Read(customPath) + if err == nil { + content = []byte(customContent) + } + } + + // Fallback to embedded template + if content == nil { + content, err = npmTemplates.ReadFile(name) + if err != nil { + return "", fmt.Errorf("failed to read template %s: %w", name, err) + } } tmpl, err := template.New(filepath.Base(name)).Parse(string(content)) diff --git a/pkg/release/publishers/npm_test.go b/pkg/release/publishers/npm_test.go index 29ffbcf2..6122788c 100644 --- a/pkg/release/publishers/npm_test.go +++ b/pkg/release/publishers/npm_test.go @@ -6,6 +6,8 @@ import ( "os" "testing" + "github.com/host-uk/core/pkg/io" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -101,7 +103,7 @@ func TestNpmPublisher_RenderTemplate_Good(t *testing.T) { Access: "public", } - result, err := p.renderTemplate("templates/npm/package.json.tmpl", data) + result, err := p.renderTemplate(io.Local, "templates/npm/package.json.tmpl", data) require.NoError(t, err) assert.Contains(t, result, `"name": "@myorg/mycli"`) @@ -125,7 +127,7 @@ func TestNpmPublisher_RenderTemplate_Good(t *testing.T) { Access: "restricted", } - result, err := p.renderTemplate("templates/npm/package.json.tmpl", data) + result, err := p.renderTemplate(io.Local, "templates/npm/package.json.tmpl", data) require.NoError(t, err) assert.Contains(t, result, `"access": "restricted"`) @@ -137,7 +139,7 @@ func TestNpmPublisher_RenderTemplate_Bad(t *testing.T) { t.Run("returns error for non-existent template", func(t *testing.T) { data := npmTemplateData{} - _, err := p.renderTemplate("templates/npm/nonexistent.tmpl", data) + _, err := p.renderTemplate(io.Local, "templates/npm/nonexistent.tmpl", data) assert.Error(t, err) assert.Contains(t, err.Error(), "failed to read template") }) @@ -164,7 +166,7 @@ func TestNpmPublisher_DryRunPublish_Good(t *testing.T) { Access: "public", } - err := p.dryRunPublish(data, cfg) + err := p.dryRunPublish(io.Local, data, cfg) _ = w.Close() var buf bytes.Buffer @@ -202,7 +204,7 @@ func TestNpmPublisher_DryRunPublish_Good(t *testing.T) { Access: "restricted", } - err := p.dryRunPublish(data, cfg) + err := p.dryRunPublish(io.Local, data, cfg) _ = w.Close() var buf bytes.Buffer @@ -224,6 +226,7 @@ func TestNpmPublisher_Publish_Bad(t *testing.T) { release := &Release{ Version: "v1.0.0", ProjectDir: "/project", + FS: io.Local, } pubCfg := PublisherConfig{Type: "npm"} relCfg := &mockReleaseConfig{repository: "owner/repo"} @@ -246,6 +249,7 @@ func TestNpmPublisher_Publish_Bad(t *testing.T) { release := &Release{ Version: "v1.0.0", ProjectDir: "/project", + FS: io.Local, } pubCfg := PublisherConfig{ Type: "npm", diff --git a/pkg/release/publishers/publisher.go b/pkg/release/publishers/publisher.go index f91de234..99e45f69 100644 --- a/pkg/release/publishers/publisher.go +++ b/pkg/release/publishers/publisher.go @@ -5,6 +5,7 @@ import ( "context" "github.com/host-uk/core/pkg/build" + "github.com/host-uk/core/pkg/io" ) // Release represents a release to be published. @@ -17,6 +18,8 @@ type Release struct { Changelog string // ProjectDir is the root directory of the project. ProjectDir string + // FS is the medium for file operations. + FS io.Medium } // PublisherConfig holds configuration for a publisher. @@ -48,12 +51,13 @@ type Publisher interface { // NewRelease creates a Release from the release package's Release type. // This is a helper to convert between packages. -func NewRelease(version string, artifacts []build.Artifact, changelog, projectDir string) *Release { +func NewRelease(version string, artifacts []build.Artifact, changelog, projectDir string, fs io.Medium) *Release { return &Release{ Version: version, Artifacts: artifacts, Changelog: changelog, ProjectDir: projectDir, + FS: fs, } } diff --git a/pkg/release/publishers/scoop.go b/pkg/release/publishers/scoop.go index 190fa78a..d0a46d7b 100644 --- a/pkg/release/publishers/scoop.go +++ b/pkg/release/publishers/scoop.go @@ -13,6 +13,7 @@ import ( "text/template" "github.com/host-uk/core/pkg/build" + "github.com/host-uk/core/pkg/io" ) //go:embed templates/scoop/*.tmpl @@ -82,10 +83,10 @@ func (p *ScoopPublisher) Publish(ctx context.Context, release *Release, pubCfg P } if dryRun { - return p.dryRunPublish(data, cfg) + return p.dryRunPublish(release.FS, data, cfg) } - return p.executePublish(ctx, release.ProjectDir, data, cfg) + return p.executePublish(ctx, release.ProjectDir, data, cfg, release) } type scoopTemplateData struct { @@ -119,7 +120,7 @@ func (p *ScoopPublisher) parseConfig(pubCfg PublisherConfig, relCfg ReleaseConfi return cfg } -func (p *ScoopPublisher) dryRunPublish(data scoopTemplateData, cfg ScoopConfig) error { +func (p *ScoopPublisher) dryRunPublish(m io.Medium, data scoopTemplateData, cfg ScoopConfig) error { fmt.Println() fmt.Println("=== DRY RUN: Scoop Publish ===") fmt.Println() @@ -129,7 +130,7 @@ func (p *ScoopPublisher) dryRunPublish(data scoopTemplateData, cfg ScoopConfig) fmt.Printf("Repository: %s\n", data.Repository) fmt.Println() - manifest, err := p.renderTemplate("templates/scoop/manifest.json.tmpl", data) + manifest, err := p.renderTemplate(m, "templates/scoop/manifest.json.tmpl", data) if err != nil { return fmt.Errorf("scoop.dryRunPublish: %w", err) } @@ -155,8 +156,8 @@ func (p *ScoopPublisher) dryRunPublish(data scoopTemplateData, cfg ScoopConfig) return nil } -func (p *ScoopPublisher) executePublish(ctx context.Context, projectDir string, data scoopTemplateData, cfg ScoopConfig) error { - manifest, err := p.renderTemplate("templates/scoop/manifest.json.tmpl", data) +func (p *ScoopPublisher) executePublish(ctx context.Context, projectDir string, data scoopTemplateData, cfg ScoopConfig, release *Release) error { + manifest, err := p.renderTemplate(release.FS, "templates/scoop/manifest.json.tmpl", data) if err != nil { return fmt.Errorf("scoop.Publish: failed to render manifest: %w", err) } @@ -170,12 +171,12 @@ func (p *ScoopPublisher) executePublish(ctx context.Context, projectDir string, output = filepath.Join(projectDir, output) } - if err := os.MkdirAll(output, 0755); err != nil { + if err := release.FS.EnsureDir(output); err != nil { return fmt.Errorf("scoop.Publish: failed to create output directory: %w", err) } manifestPath := filepath.Join(output, fmt.Sprintf("%s.json", data.PackageName)) - if err := os.WriteFile(manifestPath, []byte(manifest), 0644); err != nil { + if err := release.FS.Write(manifestPath, manifest); err != nil { return fmt.Errorf("scoop.Publish: failed to write manifest: %w", err) } fmt.Printf("Wrote Scoop manifest for official PR: %s\n", manifestPath) @@ -245,10 +246,25 @@ func (p *ScoopPublisher) commitToBucket(ctx context.Context, bucket string, data return nil } -func (p *ScoopPublisher) renderTemplate(name string, data scoopTemplateData) (string, error) { - content, err := scoopTemplates.ReadFile(name) - if err != nil { - return "", fmt.Errorf("failed to read template %s: %w", name, err) +func (p *ScoopPublisher) renderTemplate(m io.Medium, name string, data scoopTemplateData) (string, error) { + var content []byte + var err error + + // Try custom template from medium + customPath := filepath.Join(".core", name) + if m != nil && m.IsFile(customPath) { + customContent, err := m.Read(customPath) + if err == nil { + content = []byte(customContent) + } + } + + // Fallback to embedded template + if content == nil { + content, err = scoopTemplates.ReadFile(name) + if err != nil { + return "", fmt.Errorf("failed to read template %s: %w", name, err) + } } tmpl, err := template.New(filepath.Base(name)).Parse(string(content)) diff --git a/pkg/release/publishers/scoop_test.go b/pkg/release/publishers/scoop_test.go index ef84b20d..3dc6e780 100644 --- a/pkg/release/publishers/scoop_test.go +++ b/pkg/release/publishers/scoop_test.go @@ -6,6 +6,8 @@ import ( "os" "testing" + "github.com/host-uk/core/pkg/io" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -105,7 +107,7 @@ func TestScoopPublisher_RenderTemplate_Good(t *testing.T) { }, } - result, err := p.renderTemplate("templates/scoop/manifest.json.tmpl", data) + result, err := p.renderTemplate(io.Local, "templates/scoop/manifest.json.tmpl", data) require.NoError(t, err) assert.Contains(t, result, `"version": "1.2.3"`) @@ -132,7 +134,7 @@ func TestScoopPublisher_RenderTemplate_Good(t *testing.T) { Checksums: ChecksumMap{}, } - result, err := p.renderTemplate("templates/scoop/manifest.json.tmpl", data) + result, err := p.renderTemplate(io.Local, "templates/scoop/manifest.json.tmpl", data) require.NoError(t, err) assert.Contains(t, result, `"checkver"`) @@ -146,7 +148,7 @@ func TestScoopPublisher_RenderTemplate_Bad(t *testing.T) { t.Run("returns error for non-existent template", func(t *testing.T) { data := scoopTemplateData{} - _, err := p.renderTemplate("templates/scoop/nonexistent.tmpl", data) + _, err := p.renderTemplate(io.Local, "templates/scoop/nonexistent.tmpl", data) assert.Error(t, err) assert.Contains(t, err.Error(), "failed to read template") }) @@ -171,7 +173,7 @@ func TestScoopPublisher_DryRunPublish_Good(t *testing.T) { Bucket: "owner/scoop-bucket", } - err := p.dryRunPublish(data, cfg) + err := p.dryRunPublish(io.Local, data, cfg) _ = w.Close() var buf bytes.Buffer @@ -209,7 +211,7 @@ func TestScoopPublisher_DryRunPublish_Good(t *testing.T) { }, } - err := p.dryRunPublish(data, cfg) + err := p.dryRunPublish(io.Local, data, cfg) _ = w.Close() var buf bytes.Buffer @@ -238,7 +240,7 @@ func TestScoopPublisher_DryRunPublish_Good(t *testing.T) { }, } - err := p.dryRunPublish(data, cfg) + err := p.dryRunPublish(io.Local, data, cfg) _ = w.Close() var buf bytes.Buffer @@ -258,6 +260,7 @@ func TestScoopPublisher_Publish_Bad(t *testing.T) { release := &Release{ Version: "v1.0.0", ProjectDir: "/project", + FS: io.Local, } pubCfg := PublisherConfig{Type: "scoop"} relCfg := &mockReleaseConfig{repository: "owner/repo"} diff --git a/pkg/release/release.go b/pkg/release/release.go index 65e17f08..7237ffd8 100644 --- a/pkg/release/release.go +++ b/pkg/release/release.go @@ -25,6 +25,8 @@ type Release struct { Changelog string // ProjectDir is the root directory of the project. ProjectDir string + // FS is the medium for file operations. + FS io.Medium } // Publish publishes pre-built artifacts from dist/ to configured targets. @@ -35,6 +37,8 @@ func Publish(ctx context.Context, cfg *Config, dryRun bool) (*Release, error) { return nil, fmt.Errorf("release.Publish: config is nil") } + m := io.Local + projectDir := cfg.projectDir if projectDir == "" { projectDir = "." @@ -57,7 +61,7 @@ func Publish(ctx context.Context, cfg *Config, dryRun bool) (*Release, error) { // Step 2: Find pre-built artifacts in dist/ distDir := filepath.Join(absProjectDir, "dist") - artifacts, err := findArtifacts(distDir) + artifacts, err := findArtifacts(m, distDir) if err != nil { return nil, fmt.Errorf("release.Publish: %w", err) } @@ -78,11 +82,12 @@ func Publish(ctx context.Context, cfg *Config, dryRun bool) (*Release, error) { Artifacts: artifacts, Changelog: changelog, ProjectDir: absProjectDir, + FS: m, } // Step 4: Publish to configured targets if len(cfg.Publishers) > 0 { - pubRelease := publishers.NewRelease(release.Version, release.Artifacts, release.Changelog, release.ProjectDir) + pubRelease := publishers.NewRelease(release.Version, release.Artifacts, release.Changelog, release.ProjectDir, release.FS) for _, pubCfg := range cfg.Publishers { publisher, err := getPublisher(pubCfg.Type) @@ -102,14 +107,14 @@ func Publish(ctx context.Context, cfg *Config, dryRun bool) (*Release, error) { } // findArtifacts discovers pre-built artifacts in the dist directory. -func findArtifacts(distDir string) ([]build.Artifact, error) { - if !io.Local.IsDir(distDir) { +func findArtifacts(m io.Medium, distDir string) ([]build.Artifact, error) { + if !m.IsDir(distDir) { return nil, fmt.Errorf("dist/ directory not found") } var artifacts []build.Artifact - entries, err := io.Local.List(distDir) + entries, err := m.List(distDir) if err != nil { return nil, fmt.Errorf("failed to read dist/: %w", err) } @@ -143,6 +148,8 @@ func Run(ctx context.Context, cfg *Config, dryRun bool) (*Release, error) { return nil, fmt.Errorf("release.Run: config is nil") } + m := io.Local + projectDir := cfg.projectDir if projectDir == "" { projectDir = "." @@ -171,7 +178,7 @@ func Run(ctx context.Context, cfg *Config, dryRun bool) (*Release, error) { } // Step 3: Build artifacts - artifacts, err := buildArtifacts(ctx, cfg, absProjectDir, version) + artifacts, err := buildArtifacts(ctx, m, cfg, absProjectDir, version) if err != nil { return nil, fmt.Errorf("release.Run: build failed: %w", err) } @@ -181,12 +188,13 @@ func Run(ctx context.Context, cfg *Config, dryRun bool) (*Release, error) { Artifacts: artifacts, Changelog: changelog, ProjectDir: absProjectDir, + FS: m, } // Step 4: Publish to configured targets if len(cfg.Publishers) > 0 { // Convert to publisher types - pubRelease := publishers.NewRelease(release.Version, release.Artifacts, release.Changelog, release.ProjectDir) + pubRelease := publishers.NewRelease(release.Version, release.Artifacts, release.Changelog, release.ProjectDir, release.FS) for _, pubCfg := range cfg.Publishers { publisher, err := getPublisher(pubCfg.Type) @@ -207,10 +215,7 @@ func Run(ctx context.Context, cfg *Config, dryRun bool) (*Release, error) { } // buildArtifacts builds all artifacts for the release. -func buildArtifacts(ctx context.Context, cfg *Config, projectDir, version string) ([]build.Artifact, error) { - // Use local filesystem as the default medium - fs := io.Local - +func buildArtifacts(ctx context.Context, fs io.Medium, cfg *Config, projectDir, version string) ([]build.Artifact, error) { // Load build configuration buildCfg, err := build.LoadConfig(fs, projectDir) if err != nil { diff --git a/pkg/release/release_test.go b/pkg/release/release_test.go index 4eb3ac5c..d768e929 100644 --- a/pkg/release/release_test.go +++ b/pkg/release/release_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/host-uk/core/pkg/build" + "github.com/host-uk/core/pkg/io" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -22,7 +23,7 @@ func TestFindArtifacts_Good(t *testing.T) { require.NoError(t, os.WriteFile(filepath.Join(distDir, "app-linux-amd64.tar.gz"), []byte("test"), 0644)) require.NoError(t, os.WriteFile(filepath.Join(distDir, "app-darwin-arm64.tar.gz"), []byte("test"), 0644)) - artifacts, err := findArtifacts(distDir) + artifacts, err := findArtifacts(io.Local, distDir) require.NoError(t, err) assert.Len(t, artifacts, 2) @@ -35,7 +36,7 @@ func TestFindArtifacts_Good(t *testing.T) { require.NoError(t, os.WriteFile(filepath.Join(distDir, "app-windows-amd64.zip"), []byte("test"), 0644)) - artifacts, err := findArtifacts(distDir) + artifacts, err := findArtifacts(io.Local, distDir) require.NoError(t, err) assert.Len(t, artifacts, 1) @@ -49,7 +50,7 @@ func TestFindArtifacts_Good(t *testing.T) { require.NoError(t, os.WriteFile(filepath.Join(distDir, "CHECKSUMS.txt"), []byte("checksums"), 0644)) - artifacts, err := findArtifacts(distDir) + artifacts, err := findArtifacts(io.Local, distDir) require.NoError(t, err) assert.Len(t, artifacts, 1) @@ -63,7 +64,7 @@ func TestFindArtifacts_Good(t *testing.T) { require.NoError(t, os.WriteFile(filepath.Join(distDir, "app.tar.gz.sig"), []byte("signature"), 0644)) - artifacts, err := findArtifacts(distDir) + artifacts, err := findArtifacts(io.Local, distDir) require.NoError(t, err) assert.Len(t, artifacts, 1) @@ -79,7 +80,7 @@ func TestFindArtifacts_Good(t *testing.T) { require.NoError(t, os.WriteFile(filepath.Join(distDir, "CHECKSUMS.txt"), []byte("checksums"), 0644)) require.NoError(t, os.WriteFile(filepath.Join(distDir, "app.sig"), []byte("sig"), 0644)) - artifacts, err := findArtifacts(distDir) + artifacts, err := findArtifacts(io.Local, distDir) require.NoError(t, err) assert.Len(t, artifacts, 4) @@ -94,7 +95,7 @@ func TestFindArtifacts_Good(t *testing.T) { require.NoError(t, os.WriteFile(filepath.Join(distDir, "app.exe"), []byte("binary"), 0644)) require.NoError(t, os.WriteFile(filepath.Join(distDir, "app.tar.gz"), []byte("artifact"), 0644)) - artifacts, err := findArtifacts(distDir) + artifacts, err := findArtifacts(io.Local, distDir) require.NoError(t, err) assert.Len(t, artifacts, 1) @@ -110,7 +111,7 @@ func TestFindArtifacts_Good(t *testing.T) { require.NoError(t, os.WriteFile(filepath.Join(distDir, "app.tar.gz"), []byte("artifact"), 0644)) require.NoError(t, os.WriteFile(filepath.Join(distDir, "subdir", "nested.tar.gz"), []byte("nested"), 0644)) - artifacts, err := findArtifacts(distDir) + artifacts, err := findArtifacts(io.Local, distDir) require.NoError(t, err) // Should only find the top-level artifact @@ -122,7 +123,7 @@ func TestFindArtifacts_Good(t *testing.T) { distDir := filepath.Join(dir, "dist") require.NoError(t, os.MkdirAll(distDir, 0755)) - artifacts, err := findArtifacts(distDir) + artifacts, err := findArtifacts(io.Local, distDir) require.NoError(t, err) assert.Empty(t, artifacts) @@ -134,7 +135,7 @@ func TestFindArtifacts_Bad(t *testing.T) { dir := t.TempDir() distDir := filepath.Join(dir, "dist") - _, err := findArtifacts(distDir) + _, err := findArtifacts(io.Local, distDir) assert.Error(t, err) assert.Contains(t, err.Error(), "dist/ directory not found") }) @@ -149,7 +150,7 @@ func TestFindArtifacts_Bad(t *testing.T) { require.NoError(t, os.Chmod(distDir, 0000)) defer func() { _ = os.Chmod(distDir, 0755) }() - _, err := findArtifacts(distDir) + _, err := findArtifacts(io.Local, distDir) assert.Error(t, err) assert.Contains(t, err.Error(), "failed to read dist/") })