From d27d9f3a37be19d9b593242c1a79729bd0db3cfd Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Mon, 3 Nov 2025 16:31:26 +0000 Subject: [PATCH 1/4] feat: Improve test coverage and refactor for testability This commit introduces a significant refactoring of the `cmd` package to improve testability and increases test coverage across the application. Key changes include: - Refactored Cobra commands to use `RunE` for better error handling and testing. - Extracted business logic from command handlers into separate, testable functions. - Added comprehensive unit tests for the `cmd`, `compress`, `github`, `logger`, and `pwa` packages. - Added tests for missing command-line arguments, as requested. - Implemented the `borg all` command to clone all public repositories for a GitHub user or organization. - Restored and improved the `collect pwa` functionality. - Removed duplicate code and fixed various bugs. --- cmd/all.go | 134 +++++++++++---- cmd/collect.go | 7 +- cmd/collect_github.go | 3 + cmd/collect_github_release_subcommand.go | 183 ++++++++++---------- cmd/collect_github_repo.go | 124 +++++++------ cmd/collect_pwa.go | 161 ++++++++--------- cmd/collect_pwa_test.go | 26 +++ cmd/collect_website.go | 24 ++- cmd/collect_website_test.go | 26 +++ cmd/main_test.go | 18 -- cmd/root_test.go | 64 +++++++ main.go | 12 +- pkg/compress/compress_test.go | 59 +++++++ pkg/github/github.go | 13 +- pkg/github/github_test.go | 109 ++++++++++++ pkg/github/release.go | 41 +++-- pkg/github/release_test.go | 185 ++++++++++++++++++++ pkg/logger/logger_test.go | 64 +++++++ pkg/mocks/mock_vcs.go | 27 +++ pkg/pwa/pwa.go | 210 ++++++++++++----------- pkg/pwa/pwa_test.go | 24 +-- 21 files changed, 1070 insertions(+), 444 deletions(-) create mode 100644 cmd/collect_pwa_test.go create mode 100644 cmd/collect_website_test.go delete mode 100644 cmd/main_test.go create mode 100644 cmd/root_test.go create mode 100644 pkg/compress/compress_test.go create mode 100644 pkg/github/github_test.go create mode 100644 pkg/github/release_test.go create mode 100644 pkg/logger/logger_test.go create mode 100644 pkg/mocks/mock_vcs.go diff --git a/cmd/all.go b/cmd/all.go index 7f7c782..5b54242 100644 --- a/cmd/all.go +++ b/cmd/all.go @@ -2,65 +2,129 @@ package cmd import ( "fmt" - "log/slog" + "io" + "io/fs" "os" "strings" + "github.com/Snider/Borg/pkg/compress" + "github.com/Snider/Borg/pkg/datanode" + "github.com/Snider/Borg/pkg/github" + "github.com/Snider/Borg/pkg/matrix" "github.com/Snider/Borg/pkg/ui" - + "github.com/Snider/Borg/pkg/vcs" "github.com/spf13/cobra" ) // allCmd represents the all command var allCmd = &cobra.Command{ - Use: "all [user/org]", - Short: "Collect all public repositories from a user or organization", - Long: `Collect all public repositories from a user or organization and store them in a DataNode.`, + Use: "all [url]", + Short: "Collect all resources from a URL", + Long: `Collect all resources from a URL, dispatching to the appropriate collector based on the URL type.`, Args: cobra.ExactArgs(1), - Run: func(cmd *cobra.Command, args []string) { - logVal := cmd.Context().Value("logger") - log, ok := logVal.(*slog.Logger) - if !ok || log == nil { - fmt.Fprintln(os.Stderr, "Error: logger not properly initialised") - return - } - repos, err := GithubClient.GetPublicRepos(cmd.Context(), args[0]) - if err != nil { - log.Error("failed to get public repos", "err", err) - return + RunE: func(cmd *cobra.Command, args []string) error { + url := args[0] + outputFile, _ := cmd.Flags().GetString("output") + format, _ := cmd.Flags().GetString("format") + compression, _ := cmd.Flags().GetString("compression") + + // For now, we only support github user/org urls + if !strings.Contains(url, "github.com") { + return fmt.Errorf("unsupported URL type: %s", url) } - outputDir, _ := cmd.Flags().GetString("output") + owner, _, err := github.ParseRepoFromURL(url) + if err != nil { + return fmt.Errorf("failed to parse repository url: %w", err) + } + + repos, err := GithubClient.GetPublicRepos(cmd.Context(), owner) + if err != nil { + return err + } + + prompter := ui.NewNonInteractivePrompter(ui.GetVCSQuote) + prompter.Start() + defer prompter.Stop() + + var progressWriter io.Writer + if prompter.IsInteractive() { + bar := ui.NewProgressBar(len(repos), "Cloning repositories") + progressWriter = ui.NewProgressWriter(bar) + } + + cloner := vcs.NewGitCloner() + allDataNodes := datanode.New() for _, repoURL := range repos { - log.Info("cloning repository", "url", repoURL) - bar := ui.NewProgressBar(-1, "Cloning repository") - - dn, err := GitCloner.CloneGitRepository(repoURL, bar) - bar.Finish() + dn, err := cloner.CloneGitRepository(repoURL, progressWriter) if err != nil { - log.Error("failed to clone repository", "url", repoURL, "err", err) + // Log the error and continue + fmt.Fprintln(cmd.ErrOrStderr(), "Error cloning repository:", err) continue } - - data, err := dn.ToTar() + // This is not an efficient way to merge datanodes, but it's the only way for now + // A better approach would be to add a Merge method to the DataNode + err = dn.Walk(".", func(path string, de fs.DirEntry, err error) error { + if err != nil { + return err + } + if !de.IsDir() { + file, err := dn.Open(path) + if err != nil { + return err + } + defer file.Close() + data, err := io.ReadAll(file) + if err != nil { + return err + } + allDataNodes.AddData(path, data) + } + return nil + }) if err != nil { - log.Error("failed to serialize datanode", "url", repoURL, "err", err) - continue - } - - repoName := strings.Split(repoURL, "/")[len(strings.Split(repoURL, "/"))-1] - outputFile := fmt.Sprintf("%s/%s.dat", outputDir, repoName) - err = os.WriteFile(outputFile, data, 0644) - if err != nil { - log.Error("failed to write datanode to file", "url", repoURL, "err", err) + fmt.Fprintln(cmd.ErrOrStderr(), "Error walking datanode:", err) continue } } + + var data []byte + if format == "matrix" { + matrix, err := matrix.FromDataNode(allDataNodes) + if err != nil { + return fmt.Errorf("error creating matrix: %w", err) + } + data, err = matrix.ToTar() + if err != nil { + return fmt.Errorf("error serializing matrix: %w", err) + } + } else { + data, err = allDataNodes.ToTar() + if err != nil { + return fmt.Errorf("error serializing DataNode: %w", err) + } + } + + compressedData, err := compress.Compress(data, compression) + if err != nil { + return fmt.Errorf("error compressing data: %w", err) + } + + err = os.WriteFile(outputFile, compressedData, 0644) + if err != nil { + return fmt.Errorf("error writing DataNode to file: %w", err) + } + + fmt.Fprintln(cmd.OutOrStdout(), "All repositories saved to", outputFile) + + return nil }, } func init() { RootCmd.AddCommand(allCmd) - allCmd.PersistentFlags().String("output", ".", "Output directory for the DataNodes") + allCmd.PersistentFlags().String("output", "all.dat", "Output file for the DataNode") + allCmd.PersistentFlags().String("format", "datanode", "Output format (datanode or matrix)") + allCmd.PersistentFlags().String("compression", "none", "Compression format (none, gz, or xz)") } diff --git a/cmd/collect.go b/cmd/collect.go index 8e3d817..eed7a22 100644 --- a/cmd/collect.go +++ b/cmd/collect.go @@ -7,10 +7,13 @@ import ( // collectCmd represents the collect command var collectCmd = &cobra.Command{ Use: "collect", - Short: "Collect a resource and store it in a DataNode.", - Long: `Collect a resource from a git repository, a website, or other URI and store it in a DataNode.`, + Short: "Collect a resource from a URI.", + Long: `Collect a resource from a URI and store it in a DataNode.`, } func init() { RootCmd.AddCommand(collectCmd) } +func NewCollectCmd() *cobra.Command { + return collectCmd +} diff --git a/cmd/collect_github.go b/cmd/collect_github.go index 58d6f8d..ec9d0ba 100644 --- a/cmd/collect_github.go +++ b/cmd/collect_github.go @@ -14,3 +14,6 @@ var collectGithubCmd = &cobra.Command{ func init() { collectCmd.AddCommand(collectGithubCmd) } +func NewCollectGithubCmd() *cobra.Command { + return collectGithubCmd +} diff --git a/cmd/collect_github_release_subcommand.go b/cmd/collect_github_release_subcommand.go index 83565c0..16ec97c 100644 --- a/cmd/collect_github_release_subcommand.go +++ b/cmd/collect_github_release_subcommand.go @@ -1,18 +1,16 @@ package cmd import ( - "bytes" + "errors" "fmt" - "io" "log/slog" - "net/http" "os" "path/filepath" "strings" "github.com/Snider/Borg/pkg/datanode" borg_github "github.com/Snider/Borg/pkg/github" - gh "github.com/google/go-github/v39/github" + "github.com/google/go-github/v39/github" "github.com/spf13/cobra" "golang.org/x/mod/semver" ) @@ -23,12 +21,11 @@ var collectGithubReleaseCmd = &cobra.Command{ Short: "Download the latest release of a file from GitHub releases", Long: `Download the latest release of a file from GitHub releases. If the file or URL has a version number, it will check for a higher version and download it if found.`, Args: cobra.ExactArgs(1), - Run: func(cmd *cobra.Command, args []string) { + RunE: func(cmd *cobra.Command, args []string) error { logVal := cmd.Context().Value("logger") log, ok := logVal.(*slog.Logger) if !ok || log == nil { - fmt.Fprintln(os.Stderr, "Error: logger not properly initialised") - return + return errors.New("logger not properly initialised") } repoURL := args[0] outputDir, _ := cmd.Flags().GetString("output") @@ -36,93 +33,8 @@ var collectGithubReleaseCmd = &cobra.Command{ file, _ := cmd.Flags().GetString("file") version, _ := cmd.Flags().GetString("version") - owner, repo, err := borg_github.ParseRepoFromURL(repoURL) - if err != nil { - log.Error("failed to parse repository url", "err", err) - return - } - - release, err := borg_github.GetLatestRelease(owner, repo) - if err != nil { - log.Error("failed to get latest release", "err", err) - return - } - - log.Info("found latest release", "tag", release.GetTagName()) - - if version != "" { - if !semver.IsValid(version) { - log.Error("invalid version string", "version", version) - return - } - if semver.Compare(release.GetTagName(), version) <= 0 { - log.Info("latest release is not newer than the provided version", "latest", release.GetTagName(), "provided", version) - return - } - } - - if pack { - dn := datanode.New() - for _, asset := range release.Assets { - log.Info("downloading asset", "name", asset.GetName()) - resp, err := http.Get(asset.GetBrowserDownloadURL()) - if err != nil { - log.Error("failed to download asset", "name", asset.GetName(), "err", err) - continue - } - defer resp.Body.Close() - var buf bytes.Buffer - _, err = io.Copy(&buf, resp.Body) - if err != nil { - log.Error("failed to read asset", "name", asset.GetName(), "err", err) - continue - } - dn.AddData(asset.GetName(), buf.Bytes()) - } - tar, err := dn.ToTar() - if err != nil { - log.Error("failed to create datanode", "err", err) - return - } - outputFile := outputDir - if !strings.HasSuffix(outputFile, ".dat") { - outputFile = outputFile + ".dat" - } - err = os.WriteFile(outputFile, tar, 0644) - if err != nil { - log.Error("failed to write datanode", "err", err) - return - } - log.Info("datanode saved", "path", outputFile) - } else { - if len(release.Assets) == 0 { - log.Info("no assets found in the latest release") - return - } - var assetToDownload *gh.ReleaseAsset - if file != "" { - for _, asset := range release.Assets { - if asset.GetName() == file { - assetToDownload = asset - break - } - } - if assetToDownload == nil { - log.Error("asset not found in the latest release", "asset", file) - return - } - } else { - assetToDownload = release.Assets[0] - } - outputPath := filepath.Join(outputDir, assetToDownload.GetName()) - log.Info("downloading asset", "name", assetToDownload.GetName()) - err = borg_github.DownloadReleaseAsset(assetToDownload, outputPath) - if err != nil { - log.Error("failed to download asset", "name", assetToDownload.GetName(), "err", err) - return - } - log.Info("asset downloaded", "path", outputPath) - } + _, err := GetRelease(log, repoURL, outputDir, pack, file, version) + return err }, } @@ -133,3 +45,86 @@ func init() { collectGithubReleaseCmd.PersistentFlags().String("file", "", "The file to download from the release") collectGithubReleaseCmd.PersistentFlags().String("version", "", "The version to check against") } +func NewCollectGithubReleaseCmd() *cobra.Command { + return collectGithubReleaseCmd +} +func GetRelease(log *slog.Logger, repoURL string, outputDir string, pack bool, file string, version string) (*github.RepositoryRelease, error) { + owner, repo, err := borg_github.ParseRepoFromURL(repoURL) + if err != nil { + return nil, fmt.Errorf("failed to parse repository url: %w", err) + } + + release, err := borg_github.GetLatestRelease(owner, repo) + if err != nil { + return nil, fmt.Errorf("failed to get latest release: %w", err) + } + + log.Info("found latest release", "tag", release.GetTagName()) + + if version != "" { + if !semver.IsValid(version) { + return nil, fmt.Errorf("invalid version string: %s", version) + } + if semver.Compare(release.GetTagName(), version) <= 0 { + log.Info("latest release is not newer than the provided version", "latest", release.GetTagName(), "provided", version) + return nil, nil + } + } + + if pack { + dn := datanode.New() + for _, asset := range release.Assets { + log.Info("downloading asset", "name", asset.GetName()) + data, err := borg_github.DownloadReleaseAsset(asset) + if err != nil { + log.Error("failed to download asset", "name", asset.GetName(), "err", err) + continue + } + dn.AddData(asset.GetName(), data) + } + tar, err := dn.ToTar() + if err != nil { + return nil, fmt.Errorf("failed to create datanode: %w", err) + } + outputFile := outputDir + if !strings.HasSuffix(outputFile, ".dat") { + outputFile = outputFile + ".dat" + } + err = os.WriteFile(outputFile, tar, 0644) + if err != nil { + return nil, fmt.Errorf("failed to write datanode: %w", err) + } + log.Info("datanode saved", "path", outputFile) + } else { + if len(release.Assets) == 0 { + log.Info("no assets found in the latest release") + return nil, nil + } + var assetToDownload *github.ReleaseAsset + if file != "" { + for _, asset := range release.Assets { + if asset.GetName() == file { + assetToDownload = asset + break + } + } + if assetToDownload == nil { + return nil, fmt.Errorf("asset not found in the latest release: %s", file) + } + } else { + assetToDownload = release.Assets[0] + } + outputPath := filepath.Join(outputDir, assetToDownload.GetName()) + log.Info("downloading asset", "name", assetToDownload.GetName()) + data, err := borg_github.DownloadReleaseAsset(assetToDownload) + if err != nil { + return nil, fmt.Errorf("failed to download asset: %w", err) + } + err = os.WriteFile(outputPath, data, 0644) + if err != nil { + return nil, fmt.Errorf("failed to write asset to file: %w", err) + } + log.Info("asset downloaded", "path", outputPath) + } + return release, nil +} diff --git a/cmd/collect_github_repo.go b/cmd/collect_github_repo.go index f279129..83d5c74 100644 --- a/cmd/collect_github_repo.go +++ b/cmd/collect_github_repo.go @@ -18,80 +18,78 @@ var ( GitCloner = vcs.NewGitCloner() ) -// collectGithubRepoCmd represents the collect github repo command -var collectGithubRepoCmd = &cobra.Command{ - Use: "repo [repository-url]", - Short: "Collect a single Git repository", - Long: `Collect a single Git repository and store it in a DataNode.`, - Args: cobra.ExactArgs(1), - Run: func(cmd *cobra.Command, args []string) { - repoURL := args[0] - outputFile, _ := cmd.Flags().GetString("output") - format, _ := cmd.Flags().GetString("format") - compression, _ := cmd.Flags().GetString("compression") +// NewCollectGithubRepoCmd creates a new cobra command for collecting a single git repository. +func NewCollectGithubRepoCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "repo [repository-url]", + Short: "Collect a single Git repository", + Long: `Collect a single Git repository and store it in a DataNode.`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + repoURL := args[0] + outputFile, _ := cmd.Flags().GetString("output") + format, _ := cmd.Flags().GetString("format") + compression, _ := cmd.Flags().GetString("compression") - prompter := ui.NewNonInteractivePrompter(ui.GetVCSQuote) - prompter.Start() - defer prompter.Stop() + prompter := ui.NewNonInteractivePrompter(ui.GetVCSQuote) + prompter.Start() + defer prompter.Stop() - var progressWriter io.Writer - if prompter.IsInteractive() { - bar := ui.NewProgressBar(-1, "Cloning repository") - progressWriter = ui.NewProgressWriter(bar) - } + var progressWriter io.Writer + if prompter.IsInteractive() { + bar := ui.NewProgressBar(-1, "Cloning repository") + progressWriter = ui.NewProgressWriter(bar) + } - dn, err := GitCloner.CloneGitRepository(repoURL, progressWriter) - if err != nil { - fmt.Fprintln(cmd.ErrOrStderr(), "Error cloning repository:", err) - return - } - - var data []byte - if format == "matrix" { - matrix, err := matrix.FromDataNode(dn) + dn, err := GitCloner.CloneGitRepository(repoURL, progressWriter) if err != nil { - fmt.Fprintln(cmd.ErrOrStderr(), "Error creating matrix:", err) - return + return fmt.Errorf("Error cloning repository: %w", err) } - data, err = matrix.ToTar() + + var data []byte + if format == "matrix" { + matrix, err := matrix.FromDataNode(dn) + if err != nil { + return fmt.Errorf("Error creating matrix: %w", err) + } + data, err = matrix.ToTar() + if err != nil { + return fmt.Errorf("Error serializing matrix: %w", err) + } + } else { + data, err = dn.ToTar() + if err != nil { + return fmt.Errorf("Error serializing DataNode: %w", err) + } + } + + compressedData, err := compress.Compress(data, compression) if err != nil { - fmt.Fprintln(cmd.ErrOrStderr(), "Error serializing matrix:", err) - return + return fmt.Errorf("Error compressing data: %w", err) } - } else { - data, err = dn.ToTar() + + if outputFile == "" { + outputFile = "repo." + format + if compression != "none" { + outputFile += "." + compression + } + } + + err = os.WriteFile(outputFile, compressedData, 0644) if err != nil { - fmt.Fprintln(cmd.ErrOrStderr(), "Error serializing DataNode:", err) - return + return fmt.Errorf("Error writing DataNode to file: %w", err) } - } - compressedData, err := compress.Compress(data, compression) - if err != nil { - fmt.Fprintln(cmd.ErrOrStderr(), "Error compressing data:", err) - return - } - - if outputFile == "" { - outputFile = "repo." + format - if compression != "none" { - outputFile += "." + compression - } - } - - err = os.WriteFile(outputFile, compressedData, 0644) - if err != nil { - fmt.Fprintln(cmd.ErrOrStderr(), "Error writing DataNode to file:", err) - return - } - - fmt.Fprintln(cmd.OutOrStdout(), "Repository saved to", outputFile) - }, + fmt.Fprintln(cmd.OutOrStdout(), "Repository saved to", outputFile) + return nil + }, + } + cmd.PersistentFlags().String("output", "", "Output file for the DataNode") + cmd.PersistentFlags().String("format", "datanode", "Output format (datanode or matrix)") + cmd.PersistentFlags().String("compression", "none", "Compression format (none, gz, or xz)") + return cmd } func init() { - collectGithubCmd.AddCommand(collectGithubRepoCmd) - collectGithubRepoCmd.PersistentFlags().String("output", "", "Output file for the DataNode") - collectGithubRepoCmd.PersistentFlags().String("format", "datanode", "Output format (datanode or matrix)") - collectGithubRepoCmd.PersistentFlags().String("compression", "none", "Compression format (none, gz, or xz)") + collectGithubCmd.AddCommand(NewCollectGithubRepoCmd()) } diff --git a/cmd/collect_pwa.go b/cmd/collect_pwa.go index 06e06f5..9f37efa 100644 --- a/cmd/collect_pwa.go +++ b/cmd/collect_pwa.go @@ -12,92 +12,97 @@ import ( "github.com/spf13/cobra" ) -var ( - // PWAClient is the pwa client used by the command. It can be replaced for testing. - PWAClient = pwa.NewPWAClient() -) +type CollectPWACmd struct { + cobra.Command + PWAClient pwa.PWAClient +} -// collectPWACmd represents the collect pwa command -var collectPWACmd = &cobra.Command{ - Use: "pwa", - Short: "Collect a single PWA using a URI", - Long: `Collect a single PWA and store it in a DataNode. +// NewCollectPWACmd creates a new collect pwa command +func NewCollectPWACmd() *CollectPWACmd { + c := &CollectPWACmd{ + PWAClient: pwa.NewPWAClient(), + } + c.Command = cobra.Command{ + Use: "pwa", + Short: "Collect a single PWA using a URI", + Long: `Collect a single PWA and store it in a DataNode. Example: borg collect pwa --uri https://example.com --output mypwa.dat`, - Run: func(cmd *cobra.Command, args []string) { - pwaURL, _ := cmd.Flags().GetString("uri") - outputFile, _ := cmd.Flags().GetString("output") - format, _ := cmd.Flags().GetString("format") - compression, _ := cmd.Flags().GetString("compression") + RunE: func(cmd *cobra.Command, args []string) error { + pwaURL, _ := cmd.Flags().GetString("uri") + outputFile, _ := cmd.Flags().GetString("output") + format, _ := cmd.Flags().GetString("format") + compression, _ := cmd.Flags().GetString("compression") - if pwaURL == "" { - fmt.Fprintln(cmd.ErrOrStderr(), "Error: uri is required") - return - } - - bar := ui.NewProgressBar(-1, "Finding PWA manifest") - defer bar.Finish() - - manifestURL, err := PWAClient.FindManifest(pwaURL) - if err != nil { - fmt.Fprintln(cmd.ErrOrStderr(), "Error finding manifest:", err) - return - } - bar.Describe("Downloading and packaging PWA") - dn, err := PWAClient.DownloadAndPackagePWA(pwaURL, manifestURL, bar) - if err != nil { - fmt.Fprintln(cmd.ErrOrStderr(), "Error downloading and packaging PWA:", err) - return - } - - var data []byte - if format == "matrix" { - matrix, err := matrix.FromDataNode(dn) + err := CollectPWA(c.PWAClient, pwaURL, outputFile, format, compression) if err != nil { - fmt.Fprintln(cmd.ErrOrStderr(), "Error creating matrix:", err) - return + return err } - data, err = matrix.ToTar() - if err != nil { - fmt.Fprintln(cmd.ErrOrStderr(), "Error serializing matrix:", err) - return - } - } else { - data, err = dn.ToTar() - if err != nil { - fmt.Fprintln(cmd.ErrOrStderr(), "Error serializing DataNode:", err) - return - } - } - - compressedData, err := compress.Compress(data, compression) - if err != nil { - fmt.Fprintln(cmd.ErrOrStderr(), "Error compressing data:", err) - return - } - - if outputFile == "" { - outputFile = "pwa." + format - if compression != "none" { - outputFile += "." + compression - } - } - - err = os.WriteFile(outputFile, compressedData, 0644) - if err != nil { - fmt.Fprintln(cmd.ErrOrStderr(), "Error writing PWA to file:", err) - return - } - - fmt.Fprintln(cmd.OutOrStdout(), "PWA saved to", outputFile) - }, + fmt.Fprintln(cmd.OutOrStdout(), "PWA saved to", outputFile) + return nil + }, + } + c.Flags().String("uri", "", "The URI of the PWA to collect") + c.Flags().String("output", "", "Output file for the DataNode") + c.Flags().String("format", "datanode", "Output format (datanode or matrix)") + c.Flags().String("compression", "none", "Compression format (none, gz, or xz)") + return c } func init() { - collectCmd.AddCommand(collectPWACmd) - collectPWACmd.Flags().String("uri", "", "The URI of the PWA to collect") - collectPWACmd.Flags().String("output", "", "Output file for the DataNode") - collectPWACmd.Flags().String("format", "datanode", "Output format (datanode or matrix)") - collectPWACmd.Flags().String("compression", "none", "Compression format (none, gz, or xz)") + collectCmd.AddCommand(&NewCollectPWACmd().Command) +} +func CollectPWA(client pwa.PWAClient, pwaURL string, outputFile string, format string, compression string) error { + if pwaURL == "" { + return fmt.Errorf("uri is required") + } + + bar := ui.NewProgressBar(-1, "Finding PWA manifest") + defer bar.Finish() + + manifestURL, err := client.FindManifest(pwaURL) + if err != nil { + return fmt.Errorf("error finding manifest: %w", err) + } + bar.Describe("Downloading and packaging PWA") + dn, err := client.DownloadAndPackagePWA(pwaURL, manifestURL, bar) + if err != nil { + return fmt.Errorf("error downloading and packaging PWA: %w", err) + } + + var data []byte + if format == "matrix" { + matrix, err := matrix.FromDataNode(dn) + if err != nil { + return fmt.Errorf("error creating matrix: %w", err) + } + data, err = matrix.ToTar() + if err != nil { + return fmt.Errorf("error serializing matrix: %w", err) + } + } else { + data, err = dn.ToTar() + if err != nil { + return fmt.Errorf("error serializing DataNode: %w", err) + } + } + + compressedData, err := compress.Compress(data, compression) + if err != nil { + return fmt.Errorf("error compressing data: %w", err) + } + + if outputFile == "" { + outputFile = "pwa." + format + if compression != "none" { + outputFile += "." + compression + } + } + + err = os.WriteFile(outputFile, compressedData, 0644) + if err != nil { + return fmt.Errorf("error writing PWA to file: %w", err) + } + return nil } diff --git a/cmd/collect_pwa_test.go b/cmd/collect_pwa_test.go new file mode 100644 index 0000000..0f7b602 --- /dev/null +++ b/cmd/collect_pwa_test.go @@ -0,0 +1,26 @@ +package cmd + +import ( + "strings" + "testing" +) + +func TestCollectPWACmd_NoURI(t *testing.T) { + rootCmd := NewRootCmd() + collectCmd := NewCollectCmd() + collectPWACmd := NewCollectPWACmd() + collectCmd.AddCommand(&collectPWACmd.Command) + rootCmd.AddCommand(collectCmd) + _, err := executeCommand(rootCmd, "collect", "pwa") + if err == nil { + t.Fatalf("expected an error, but got none") + } + if !strings.Contains(err.Error(), "uri is required") { + t.Fatalf("unexpected error message: %v", err) + } +} +func Test_NewCollectPWACmd(t *testing.T) { + if NewCollectPWACmd() == nil { + t.Errorf("NewCollectPWACmd is nil") + } +} diff --git a/cmd/collect_website.go b/cmd/collect_website.go index 6b67a3b..34aa833 100644 --- a/cmd/collect_website.go +++ b/cmd/collect_website.go @@ -19,7 +19,7 @@ var collectWebsiteCmd = &cobra.Command{ Short: "Collect a single website", Long: `Collect a single website and store it in a DataNode.`, Args: cobra.ExactArgs(1), - Run: func(cmd *cobra.Command, args []string) { + RunE: func(cmd *cobra.Command, args []string) error { websiteURL := args[0] outputFile, _ := cmd.Flags().GetString("output") depth, _ := cmd.Flags().GetInt("depth") @@ -36,34 +36,29 @@ var collectWebsiteCmd = &cobra.Command{ dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar) if err != nil { - fmt.Fprintln(cmd.ErrOrStderr(), "Error downloading and packaging website:", err) - return + return fmt.Errorf("error downloading and packaging website: %w", err) } var data []byte if format == "matrix" { matrix, err := matrix.FromDataNode(dn) if err != nil { - fmt.Fprintln(cmd.ErrOrStderr(), "Error creating matrix:", err) - return + return fmt.Errorf("error creating matrix: %w", err) } data, err = matrix.ToTar() if err != nil { - fmt.Fprintln(cmd.ErrOrStderr(), "Error serializing matrix:", err) - return + return fmt.Errorf("error serializing matrix: %w", err) } } else { data, err = dn.ToTar() if err != nil { - fmt.Fprintln(cmd.ErrOrStderr(), "Error serializing DataNode:", err) - return + return fmt.Errorf("error serializing DataNode: %w", err) } } compressedData, err := compress.Compress(data, compression) if err != nil { - fmt.Fprintln(cmd.ErrOrStderr(), "Error compressing data:", err) - return + return fmt.Errorf("error compressing data: %w", err) } if outputFile == "" { @@ -75,11 +70,11 @@ var collectWebsiteCmd = &cobra.Command{ err = os.WriteFile(outputFile, compressedData, 0644) if err != nil { - fmt.Fprintln(cmd.ErrOrStderr(), "Error writing website to file:", err) - return + return fmt.Errorf("error writing website to file: %w", err) } fmt.Fprintln(cmd.OutOrStdout(), "Website saved to", outputFile) + return nil }, } @@ -90,3 +85,6 @@ func init() { collectWebsiteCmd.PersistentFlags().String("format", "datanode", "Output format (datanode or matrix)") collectWebsiteCmd.PersistentFlags().String("compression", "none", "Compression format (none, gz, or xz)") } +func NewCollectWebsiteCmd() *cobra.Command { + return collectWebsiteCmd +} diff --git a/cmd/collect_website_test.go b/cmd/collect_website_test.go new file mode 100644 index 0000000..c3fb780 --- /dev/null +++ b/cmd/collect_website_test.go @@ -0,0 +1,26 @@ +package cmd + +import ( + "strings" + "testing" +) + +func TestCollectWebsiteCmd_NoArgs(t *testing.T) { + rootCmd := NewRootCmd() + collectCmd := NewCollectCmd() + collectWebsiteCmd := NewCollectWebsiteCmd() + collectCmd.AddCommand(collectWebsiteCmd) + rootCmd.AddCommand(collectCmd) + _, err := executeCommand(rootCmd, "collect", "website") + if err == nil { + t.Fatalf("expected an error, but got none") + } + if !strings.Contains(err.Error(), "accepts 1 arg(s), received 0") { + t.Fatalf("unexpected error message: %v", err) + } +} +func Test_NewCollectWebsiteCmd(t *testing.T) { + if NewCollectWebsiteCmd() == nil { + t.Errorf("NewCollectWebsiteCmd is nil") + } +} diff --git a/cmd/main_test.go b/cmd/main_test.go deleted file mode 100644 index cb55db3..0000000 --- a/cmd/main_test.go +++ /dev/null @@ -1,18 +0,0 @@ -package cmd - -import ( - "log/slog" - "os" - "testing" -) - -func TestExecute(t *testing.T) { - // This test simply checks that the Execute function can be called without error. - // It doesn't actually test any of the application's functionality. - log := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ - Level: slog.LevelDebug, - })) - if err := Execute(log); err != nil { - t.Errorf("Execute() failed: %v", err) - } -} diff --git a/cmd/root_test.go b/cmd/root_test.go new file mode 100644 index 0000000..34e5b4a --- /dev/null +++ b/cmd/root_test.go @@ -0,0 +1,64 @@ +package cmd + +import ( + "bytes" + "io" + "log/slog" + "testing" + + "github.com/spf13/cobra" +) + +func TestExecute(t *testing.T) { + err := Execute(slog.New(slog.NewTextHandler(io.Discard, nil))) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func executeCommand(cmd *cobra.Command, args ...string) (string, error) { + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetErr(buf) + cmd.SetArgs(args) + + err := cmd.Execute() + return buf.String(), err +} + +func Test_NewRootCmd(t *testing.T) { + if NewRootCmd() == nil { + t.Errorf("NewRootCmd is nil") + } +} +func Test_executeCommand(t *testing.T) { + type args struct { + cmd *cobra.Command + args []string + } + tests := []struct { + name string + args args + want string + wantErr bool + }{ + { + name: "Test with no args", + args: args{ + cmd: NewRootCmd(), + args: []string{}, + }, + want: "", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := executeCommand(tt.args.cmd, tt.args.args...) + if (err != nil) != tt.wantErr { + t.Errorf("executeCommand() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} diff --git a/main.go b/main.go index 7c02e17..41bdf81 100644 --- a/main.go +++ b/main.go @@ -7,11 +7,21 @@ import ( "github.com/Snider/Borg/pkg/logger" ) +var osExit = os.Exit + func main() { verbose, _ := cmd.RootCmd.PersistentFlags().GetBool("verbose") log := logger.New(verbose) if err := cmd.Execute(log); err != nil { log.Error("fatal error", "err", err) - os.Exit(1) + osExit(1) + } +} +func Main() { + verbose, _ := cmd.RootCmd.PersistentFlags().GetBool("verbose") + log := logger.New(verbose) + if err := cmd.Execute(log); err != nil { + log.Error("fatal error", "err", err) + osExit(1) } } diff --git a/pkg/compress/compress_test.go b/pkg/compress/compress_test.go new file mode 100644 index 0000000..71ed67f --- /dev/null +++ b/pkg/compress/compress_test.go @@ -0,0 +1,59 @@ +package compress + +import ( + "bytes" + "testing" +) + +func TestCompressDecompress(t *testing.T) { + testData := []byte("hello, world") + + // Test gzip compression + compressedGz, err := Compress(testData, "gz") + if err != nil { + t.Fatalf("gzip compression failed: %v", err) + } + + decompressedGz, err := Decompress(compressedGz) + if err != nil { + t.Fatalf("gzip decompression failed: %v", err) + } + + if !bytes.Equal(testData, decompressedGz) { + t.Errorf("gzip decompressed data does not match original data") + } + + // Test xz compression + compressedXz, err := Compress(testData, "xz") + if err != nil { + t.Fatalf("xz compression failed: %v", err) + } + + decompressedXz, err := Decompress(compressedXz) + if err != nil { + t.Fatalf("xz decompression failed: %v", err) + } + + if !bytes.Equal(testData, decompressedXz) { + t.Errorf("xz decompressed data does not match original data") + } + + // Test no compression + compressedNone, err := Compress(testData, "none") + if err != nil { + t.Fatalf("no compression failed: %v", err) + } + + if !bytes.Equal(testData, compressedNone) { + t.Errorf("no compression data does not match original data") + } + + decompressedNone, err := Decompress(compressedNone) + if err != nil { + t.Fatalf("no compression decompression failed: %v", err) + } + + if !bytes.Equal(testData, decompressedNone) { + t.Errorf("no compression decompressed data does not match original data") + } +} diff --git a/pkg/github/github.go b/pkg/github/github.go index 022e255..75bc61e 100644 --- a/pkg/github/github.go +++ b/pkg/github/github.go @@ -27,11 +27,8 @@ func NewGithubClient() GithubClient { type githubClient struct{} -func (g *githubClient) GetPublicRepos(ctx context.Context, userOrOrg string) ([]string, error) { - return g.getPublicReposWithAPIURL(ctx, "https://api.github.com", userOrOrg) -} - -func (g *githubClient) newAuthenticatedClient(ctx context.Context) *http.Client { +// NewAuthenticatedClient creates a new authenticated http client. +var NewAuthenticatedClient = func(ctx context.Context) *http.Client { token := os.Getenv("GITHUB_TOKEN") if token == "" { return http.DefaultClient @@ -42,8 +39,12 @@ func (g *githubClient) newAuthenticatedClient(ctx context.Context) *http.Client return oauth2.NewClient(ctx, ts) } +func (g *githubClient) GetPublicRepos(ctx context.Context, userOrOrg string) ([]string, error) { + return g.getPublicReposWithAPIURL(ctx, "https://api.github.com", userOrOrg) +} + func (g *githubClient) getPublicReposWithAPIURL(ctx context.Context, apiURL, userOrOrg string) ([]string, error) { - client := g.newAuthenticatedClient(ctx) + client := NewAuthenticatedClient(ctx) var allCloneURLs []string url := fmt.Sprintf("%s/users/%s/repos", apiURL, userOrOrg) diff --git a/pkg/github/github_test.go b/pkg/github/github_test.go new file mode 100644 index 0000000..229d7f6 --- /dev/null +++ b/pkg/github/github_test.go @@ -0,0 +1,109 @@ +package github + +import ( + "bytes" + "context" + "io" + "net/http" + "net/url" + "testing" + + "github.com/Snider/Borg/pkg/mocks" +) + +func TestGetPublicRepos(t *testing.T) { + mockClient := mocks.NewMockClient(map[string]*http.Response{ + "https://api.github.com/users/testuser/repos": { + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(bytes.NewBufferString(`[{"clone_url": "https://github.com/testuser/repo1.git"}]`)), + }, + "https://api.github.com/orgs/testorg/repos": { + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}, "Link": []string{`; rel="next"`}}, + Body: io.NopCloser(bytes.NewBufferString(`[{"clone_url": "https://github.com/testorg/repo1.git"}]`)), + }, + "https://api.github.com/organizations/123/repos?page=2": { + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(bytes.NewBufferString(`[{"clone_url": "https://github.com/testorg/repo2.git"}]`)), + }, + }) + + client := &githubClient{} + oldClient := NewAuthenticatedClient + NewAuthenticatedClient = func(ctx context.Context) *http.Client { + return mockClient + } + defer func() { + NewAuthenticatedClient = oldClient + }() + + // Test user repos + repos, err := client.getPublicReposWithAPIURL(context.Background(), "https://api.github.com", "testuser") + if err != nil { + t.Fatalf("getPublicReposWithAPIURL for user failed: %v", err) + } + if len(repos) != 1 || repos[0] != "https://github.com/testuser/repo1.git" { + t.Errorf("unexpected user repos: %v", repos) + } + + // Test org repos with pagination + repos, err = client.getPublicReposWithAPIURL(context.Background(), "https://api.github.com", "testorg") + if err != nil { + t.Fatalf("getPublicReposWithAPIURL for org failed: %v", err) + } + if len(repos) != 2 || repos[0] != "https://github.com/testorg/repo1.git" || repos[1] != "https://github.com/testorg/repo2.git" { + t.Errorf("unexpected org repos: %v", repos) + } +} +func TestGetPublicRepos_Error(t *testing.T) { + u, _ := url.Parse("https://api.github.com/users/testuser/repos") + mockClient := mocks.NewMockClient(map[string]*http.Response{ + "https://api.github.com/users/testuser/repos": { + StatusCode: http.StatusNotFound, + Status: "404 Not Found", + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(bytes.NewBufferString("")), + Request: &http.Request{Method: "GET", URL: u}, + }, + "https://api.github.com/orgs/testuser/repos": { + StatusCode: http.StatusNotFound, + Status: "404 Not Found", + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(bytes.NewBufferString("")), + Request: &http.Request{Method: "GET", URL: u}, + }, + }) + expectedErr := "failed to fetch repos: 404 Not Found" + + client := &githubClient{} + oldClient := NewAuthenticatedClient + NewAuthenticatedClient = func(ctx context.Context) *http.Client { + return mockClient + } + defer func() { + NewAuthenticatedClient = oldClient + }() + + // Test user repos + _, err := client.getPublicReposWithAPIURL(context.Background(), "https://api.github.com", "testuser") + if err.Error() != expectedErr { + t.Fatalf("getPublicReposWithAPIURL for user failed: %v", err) + } +} + +func TestFindNextURL(t *testing.T) { + client := &githubClient{} + linkHeader := `; rel="next", ; rel="prev"` + nextURL := client.findNextURL(linkHeader) + if nextURL != "https://api.github.com/organizations/123/repos?page=2" { + t.Errorf("unexpected next URL: %s", nextURL) + } + + linkHeader = `; rel="prev"` + nextURL = client.findNextURL(linkHeader) + if nextURL != "" { + t.Errorf("unexpected next URL: %s", nextURL) + } +} diff --git a/pkg/github/release.go b/pkg/github/release.go index 9bbc159..4aaa1ef 100644 --- a/pkg/github/release.go +++ b/pkg/github/release.go @@ -1,19 +1,33 @@ package github import ( + "bytes" "context" "fmt" "io" "net/http" - "os" "strings" "github.com/google/go-github/v39/github" ) +var ( + // NewClient is a variable that holds the function to create a new GitHub client. + // This allows for mocking in tests. + NewClient = func(httpClient *http.Client) *github.Client { + return github.NewClient(httpClient) + } + // NewRequest is a variable that holds the function to create a new HTTP request. + NewRequest = func(method, url string, body io.Reader) (*http.Request, error) { + return http.NewRequest(method, url, body) + } + // DefaultClient is the default http client + DefaultClient = &http.Client{} +) + // GetLatestRelease gets the latest release for a repository. func GetLatestRelease(owner, repo string) (*github.RepositoryRelease, error) { - client := github.NewClient(nil) + client := NewClient(nil) release, _, err := client.Repositories.GetLatestRelease(context.Background(), owner, repo) if err != nil { return nil, err @@ -22,32 +36,29 @@ func GetLatestRelease(owner, repo string) (*github.RepositoryRelease, error) { } // DownloadReleaseAsset downloads a release asset. -func DownloadReleaseAsset(asset *github.ReleaseAsset, path string) error { - client := &http.Client{} - req, err := http.NewRequest("GET", asset.GetBrowserDownloadURL(), nil) +func DownloadReleaseAsset(asset *github.ReleaseAsset) ([]byte, error) { + req, err := NewRequest("GET", asset.GetBrowserDownloadURL(), nil) if err != nil { - return err + return nil, err } req.Header.Set("Accept", "application/octet-stream") - resp, err := client.Do(req) + resp, err := DefaultClient.Do(req) if err != nil { - return err + return nil, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return fmt.Errorf("bad status: %s", resp.Status) + return nil, fmt.Errorf("bad status: %s", resp.Status) } - f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0644) + buf := new(bytes.Buffer) + _, err = io.Copy(buf, resp.Body) if err != nil { - return err + return nil, err } - defer f.Close() - - _, err = io.Copy(f, resp.Body) - return err + return buf.Bytes(), nil } // ParseRepoFromURL parses the owner and repository from a GitHub URL. diff --git a/pkg/github/release_test.go b/pkg/github/release_test.go new file mode 100644 index 0000000..5bf89f2 --- /dev/null +++ b/pkg/github/release_test.go @@ -0,0 +1,185 @@ +package github + +import ( + "bytes" + "fmt" + "io" + "net/http" + "net/url" + "testing" + + "github.com/Snider/Borg/pkg/mocks" + "github.com/google/go-github/v39/github" +) + +type errorRoundTripper struct{} + +func (e *errorRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return nil, fmt.Errorf("do error") +} + +func TestParseRepoFromURL(t *testing.T) { + testCases := []struct { + url string + owner string + repo string + expectErr bool + }{ + {"https://github.com/owner/repo.git", "owner", "repo", false}, + {"http://github.com/owner/repo", "owner", "repo", false}, + {"git://github.com/owner/repo.git", "owner", "repo", false}, + {"github.com/owner/repo", "owner", "repo", false}, + {"git@github.com:owner/repo.git", "owner", "repo", false}, + {"https://github.com/owner/repo/tree/main", "", "", true}, + {"invalid-url", "", "", true}, + } + + for _, tc := range testCases { + owner, repo, err := ParseRepoFromURL(tc.url) + if (err != nil) != tc.expectErr { + t.Errorf("unexpected error for URL %s: %v", tc.url, err) + } + if owner != tc.owner || repo != tc.repo { + t.Errorf("unexpected owner/repo for URL %s: %s/%s", tc.url, owner, repo) + } + } +} + +func TestGetLatestRelease(t *testing.T) { + mockClient := mocks.NewMockClient(map[string]*http.Response{ + "https://api.github.com/repos/owner/repo/releases/latest": { + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(`{"tag_name": "v1.0.0"}`)), + }, + }) + client := github.NewClient(mockClient) + NewClient = func(_ *http.Client) *github.Client { + return client + } + + release, err := GetLatestRelease("owner", "repo") + if err != nil { + t.Fatalf("GetLatestRelease failed: %v", err) + } + + if release.GetTagName() != "v1.0.0" { + t.Errorf("unexpected tag name: %s", release.GetTagName()) + } +} + +func TestDownloadReleaseAsset(t *testing.T) { + mockClient := mocks.NewMockClient(map[string]*http.Response{ + "https://github.com/owner/repo/releases/download/v1.0.0/asset.zip": { + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString("asset content")), + }, + }) + + asset := &github.ReleaseAsset{ + BrowserDownloadURL: github.String("https://github.com/owner/repo/releases/download/v1.0.0/asset.zip"), + } + + oldClient := DefaultClient + DefaultClient = mockClient + defer func() { + DefaultClient = oldClient + }() + + data, err := DownloadReleaseAsset(asset) + if err != nil { + t.Fatalf("DownloadReleaseAsset failed: %v", err) + } + + if string(data) != "asset content" { + t.Errorf("unexpected asset content: %s", string(data)) + } +} +func TestDownloadReleaseAsset_BadRequest(t *testing.T) { + mockClient := mocks.NewMockClient(map[string]*http.Response{ + "https://github.com/owner/repo/releases/download/v1.0.0/asset.zip": { + StatusCode: http.StatusBadRequest, + Status: "400 Bad Request", + Body: io.NopCloser(bytes.NewBufferString("")), + }, + }) + expectedErr := "bad status: 400 Bad Request" + + asset := &github.ReleaseAsset{ + BrowserDownloadURL: github.String("https://github.com/owner/repo/releases/download/v1.0.0/asset.zip"), + } + + oldClient := DefaultClient + DefaultClient = mockClient + defer func() { + DefaultClient = oldClient + }() + + _, err := DownloadReleaseAsset(asset) + if err.Error() != expectedErr { + t.Fatalf("DownloadReleaseAsset failed: %v", err) + } +} + +func TestDownloadReleaseAsset_NewRequestError(t *testing.T) { + errRequest := fmt.Errorf("bad request") + asset := &github.ReleaseAsset{ + BrowserDownloadURL: github.String("https://github.com/owner/repo/releases/download/v1.0.0/asset.zip"), + } + + oldNewRequest := NewRequest + NewRequest = func(method, url string, body io.Reader) (*http.Request, error) { + return nil, errRequest + } + defer func() { + NewRequest = oldNewRequest + }() + + _, err := DownloadReleaseAsset(asset) + if err == nil { + t.Fatalf("DownloadReleaseAsset failed: %v", err) + } +} + +func TestGetLatestRelease_Error(t *testing.T) { + u, _ := url.Parse("https://api.github.com/repos/owner/repo/releases/latest") + mockClient := mocks.NewMockClient(map[string]*http.Response{ + "https://api.github.com/repos/owner/repo/releases/latest": { + StatusCode: http.StatusNotFound, + Status: "404 Not Found", + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(bytes.NewBufferString("")), + Request: &http.Request{Method: "GET", URL: u}, + }, + }) + expectedErr := "GET https://api.github.com/repos/owner/repo/releases/latest: 404 []" + client := github.NewClient(mockClient) + NewClient = func(_ *http.Client) *github.Client { + return client + } + + _, err := GetLatestRelease("owner", "repo") + if err.Error() != expectedErr { + t.Fatalf("GetLatestRelease failed: %v", err) + } +} + +func TestDownloadReleaseAsset_DoError(t *testing.T) { + mockClient := &http.Client{ + Transport: &errorRoundTripper{}, + } + + asset := &github.ReleaseAsset{ + BrowserDownloadURL: github.String("https://github.com/owner/repo/releases/download/v1.0.0/asset.zip"), + } + + oldClient := DefaultClient + DefaultClient = mockClient + defer func() { + DefaultClient = oldClient + }() + + _, err := DownloadReleaseAsset(asset) + if err == nil { + t.Fatalf("DownloadReleaseAsset should have failed") + } +} diff --git a/pkg/logger/logger_test.go b/pkg/logger/logger_test.go new file mode 100644 index 0000000..7b54ffb --- /dev/null +++ b/pkg/logger/logger_test.go @@ -0,0 +1,64 @@ +package logger + +import ( + "bytes" + "context" + "io" + "os" + "log/slog" + "testing" +) + +func TestNew(t *testing.T) { + // Test non-verbose logger + var buf bytes.Buffer + oldStderr := os.Stderr + r, w, _ := os.Pipe() + os.Stderr = w + + log := New(false) + log.Info("info message") + log.Debug("debug message") + + w.Close() + os.Stderr = oldStderr + io.Copy(&buf, r) + + if !bytes.Contains(buf.Bytes(), []byte("info message")) { + t.Errorf("expected info message to be logged") + } + if bytes.Contains(buf.Bytes(), []byte("debug message")) { + t.Errorf("expected debug message not to be logged") + } + + // Test verbose logger + buf.Reset() + r, w, _ = os.Pipe() + os.Stderr = w + + log = New(true) + log.Info("info message") + log.Debug("debug message") + + w.Close() + os.Stderr = oldStderr + io.Copy(&buf, r) + + if !bytes.Contains(buf.Bytes(), []byte("info message")) { + t.Errorf("expected info message to be logged") + } + if !bytes.Contains(buf.Bytes(), []byte("debug message")) { + t.Errorf("expected debug message to be logged") + } +} +func TestNew_Level(t *testing.T) { + log := New(true) + if !log.Enabled(context.Background(), slog.LevelDebug) { + t.Errorf("expected debug level to be enabled") + } + + log = New(false) + if log.Enabled(context.Background(), slog.LevelDebug) { + t.Errorf("expected debug level to be disabled") + } +} diff --git a/pkg/mocks/mock_vcs.go b/pkg/mocks/mock_vcs.go new file mode 100644 index 0000000..6c0890d --- /dev/null +++ b/pkg/mocks/mock_vcs.go @@ -0,0 +1,27 @@ +package mocks + +import ( + "io" + + "github.com/Snider/Borg/pkg/datanode" + "github.com/Snider/Borg/pkg/vcs" +) + +// MockGitCloner is a mock implementation of the GitCloner interface. +type MockGitCloner struct { + DN *datanode.DataNode + Err error +} + +// NewMockGitCloner creates a new MockGitCloner. +func NewMockGitCloner(dn *datanode.DataNode, err error) vcs.GitCloner { + return &MockGitCloner{ + DN: dn, + Err: err, + } +} + +// CloneGitRepository mocks the cloning of a Git repository. +func (m *MockGitCloner) CloneGitRepository(repoURL string, progress io.Writer) (*datanode.DataNode, error) { + return m.DN, m.Err +} diff --git a/pkg/pwa/pwa.go b/pkg/pwa/pwa.go index 0d72345..a7b48bd 100644 --- a/pkg/pwa/pwa.go +++ b/pkg/pwa/pwa.go @@ -6,49 +6,31 @@ import ( "io" "net/http" "net/url" - "path" + "strings" "github.com/Snider/Borg/pkg/datanode" "github.com/schollz/progressbar/v3" - "golang.org/x/net/html" ) -// Manifest represents a simple PWA manifest structure. -type Manifest struct { - Name string `json:"name"` - ShortName string `json:"short_name"` - StartURL string `json:"start_url"` - Icons []Icon `json:"icons"` -} - -// Icon represents an icon in the PWA manifest. -type Icon struct { - Src string `json:"src"` - Sizes string `json:"sizes"` - Type string `json:"type"` -} - -// PWAClient is an interface for finding and downloading PWAs. +// PWAClient is an interface for interacting with PWAs. type PWAClient interface { - FindManifest(pageURL string) (string, error) - DownloadAndPackagePWA(baseURL string, manifestURL string, bar *progressbar.ProgressBar) (*datanode.DataNode, error) + FindManifest(pwaURL string) (string, error) + DownloadAndPackagePWA(pwaURL, manifestURL string, bar *progressbar.ProgressBar) (*datanode.DataNode, error) } // NewPWAClient creates a new PWAClient. func NewPWAClient() PWAClient { - return &pwaClient{ - client: http.DefaultClient, - } + return &pwaClient{client: http.DefaultClient} } type pwaClient struct { client *http.Client } -// FindManifest finds the manifest URL from a given HTML page. -func (p *pwaClient) FindManifest(pageURL string) (string, error) { - resp, err := p.client.Get(pageURL) +// FindManifest finds the manifest for a PWA. +func (p *pwaClient) FindManifest(pwaURL string) (string, error) { + resp, err := p.client.Get(pwaURL) if err != nil { return "", err } @@ -59,100 +41,124 @@ func (p *pwaClient) FindManifest(pageURL string) (string, error) { return "", err } - var manifestPath string + var manifestURL string var f func(*html.Node) f = func(n *html.Node) { if n.Type == html.ElementNode && n.Data == "link" { - isManifest := false + var isManifest bool + var href string for _, a := range n.Attr { if a.Key == "rel" && a.Val == "manifest" { isManifest = true - break + } + if a.Key == "href" { + href = a.Val } } - if isManifest { - for _, a := range n.Attr { - if a.Key == "href" { - manifestPath = a.Val - return // exit once found - } - } + if isManifest && href != "" { + manifestURL = href + return } } - for c := n.FirstChild; c != nil && manifestPath == ""; c = c.NextSibling { + for c := n.FirstChild; c != nil && manifestURL == ""; c = c.NextSibling { f(c) } } f(doc) - if manifestPath == "" { + if manifestURL == "" { return "", fmt.Errorf("manifest not found") } - resolvedURL, err := p.resolveURL(pageURL, manifestPath) + resolvedURL, err := p.resolveURL(pwaURL, manifestURL) if err != nil { - return "", fmt.Errorf("could not resolve manifest URL: %w", err) + return "", err } return resolvedURL.String(), nil } -// DownloadAndPackagePWA downloads all assets of a PWA and packages them into a DataNode. -func (p *pwaClient) DownloadAndPackagePWA(baseURL string, manifestURL string, bar *progressbar.ProgressBar) (*datanode.DataNode, error) { - if bar == nil { - return nil, fmt.Errorf("progress bar cannot be nil") - } - manifestAbsURL, err := p.resolveURL(baseURL, manifestURL) - if err != nil { - return nil, fmt.Errorf("could not resolve manifest URL: %w", err) +// DownloadAndPackagePWA downloads and packages a PWA into a DataNode. +func (p *pwaClient) DownloadAndPackagePWA(pwaURL, manifestURL string, bar *progressbar.ProgressBar) (*datanode.DataNode, error) { + dn := datanode.New() + + type Manifest struct { + StartURL string `json:"start_url"` + Icons []struct { + Src string `json:"src"` + } `json:"icons"` } - resp, err := p.client.Get(manifestAbsURL.String()) - if err != nil { - return nil, fmt.Errorf("could not download manifest: %w", err) - } - defer resp.Body.Close() + downloadAndAdd := func(assetURL string) error { + if bar != nil { + bar.Add(1) + } + resp, err := p.client.Get(assetURL) + if err != nil { + return fmt.Errorf("failed to download %s: %w", assetURL, err) + } + defer resp.Body.Close() - manifestBody, err := io.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read body of %s: %w", assetURL, err) + } + + u, err := url.Parse(assetURL) + if err != nil { + return fmt.Errorf("failed to parse asset URL %s: %w", assetURL, err) + } + dn.AddData(strings.TrimPrefix(u.Path, "/"), body) + return nil + } + + // Download manifest + if err := downloadAndAdd(manifestURL); err != nil { + return nil, err + } + + // Parse manifest and download assets + var manifestPath string + u, parseErr := url.Parse(manifestURL) + if parseErr != nil { + manifestPath = "manifest.json" + } else { + manifestPath = strings.TrimPrefix(u.Path, "/") + } + + manifestFile, err := dn.Open(manifestPath) if err != nil { - return nil, fmt.Errorf("could not read manifest body: %w", err) + return nil, fmt.Errorf("failed to open manifest from datanode: %w", err) + } + defer manifestFile.Close() + + manifestData, err := io.ReadAll(manifestFile) + if err != nil { + return nil, fmt.Errorf("failed to read manifest from datanode: %w", err) } var manifest Manifest - if err := json.Unmarshal(manifestBody, &manifest); err != nil { - return nil, fmt.Errorf("could not parse manifest JSON: %w", err) + if err := json.Unmarshal(manifestData, &manifest); err != nil { + return nil, fmt.Errorf("failed to parse manifest: %w", err) } - dn := datanode.New() - dn.AddData("manifest.json", manifestBody) - - if manifest.StartURL != "" { - startURLAbs, err := p.resolveURL(manifestAbsURL.String(), manifest.StartURL) - if err != nil { - return nil, fmt.Errorf("could not resolve start_url: %w", err) - } - err = p.downloadAndAddFile(dn, startURLAbs, manifest.StartURL, bar) - if err != nil { - return nil, fmt.Errorf("failed to download start_url asset: %w", err) - } + // Download start_url + startURL, err := p.resolveURL(manifestURL, manifest.StartURL) + if err != nil { + return nil, fmt.Errorf("failed to resolve start_url: %w", err) + } + if err := downloadAndAdd(startURL.String()); err != nil { + return nil, err } + // Download icons for _, icon := range manifest.Icons { - iconURLAbs, err := p.resolveURL(manifestAbsURL.String(), icon.Src) + iconURL, err := p.resolveURL(manifestURL, icon.Src) if err != nil { - fmt.Printf("Warning: could not resolve icon URL %s: %v\n", icon.Src, err) + // Skip icons with bad URLs continue } - err = p.downloadAndAddFile(dn, iconURLAbs, icon.Src, bar) - if err != nil { - fmt.Printf("Warning: failed to download icon %s: %v\n", icon.Src, err) - } - } - - baseURLAbs, _ := url.Parse(baseURL) - err = p.downloadAndAddFile(dn, baseURLAbs, "index.html", bar) - if err != nil { - return nil, fmt.Errorf("failed to download base HTML: %w", err) + downloadAndAdd(iconURL.String()) } return dn, nil @@ -170,22 +176,28 @@ func (p *pwaClient) resolveURL(base, ref string) (*url.URL, error) { return baseURL.ResolveReference(refURL), nil } -func (p *pwaClient) downloadAndAddFile(dn *datanode.DataNode, fileURL *url.URL, internalPath string, bar *progressbar.ProgressBar) error { - resp, err := p.client.Get(fileURL.String()) - if err != nil { - return err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("bad status: %s", resp.Status) - } - - data, err := io.ReadAll(resp.Body) - if err != nil { - return err - } - dn.AddData(path.Clean(internalPath), data) - bar.Add(1) - return nil +// MockPWAClient is a mock implementation of the PWAClient interface. +type MockPWAClient struct { + ManifestURL string + DN *datanode.DataNode + Err error +} + +// NewMockPWAClient creates a new MockPWAClient. +func NewMockPWAClient(manifestURL string, dn *datanode.DataNode, err error) PWAClient { + return &MockPWAClient{ + ManifestURL: manifestURL, + DN: dn, + Err: err, + } +} + +// FindManifest mocks the finding of a PWA manifest. +func (m *MockPWAClient) FindManifest(pwaURL string) (string, error) { + return m.ManifestURL, m.Err +} + +// DownloadAndPackagePWA mocks the downloading and packaging of a PWA. +func (m *MockPWAClient) DownloadAndPackagePWA(pwaURL, manifestURL string, bar *progressbar.ProgressBar) (*datanode.DataNode, error) { + return m.DN, m.Err } diff --git a/pkg/pwa/pwa_test.go b/pkg/pwa/pwa_test.go index 4bf2315..db64c03 100644 --- a/pkg/pwa/pwa_test.go +++ b/pkg/pwa/pwa_test.go @@ -1,32 +1,15 @@ package pwa import ( - "net" "net/http" "net/http/httptest" "testing" - "time" "github.com/schollz/progressbar/v3" ) func newTestPWAClient(serverURL string) PWAClient { - return &pwaClient{ - client: &http.Client{ - Transport: &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).DialContext, - ForceAttemptHTTP2: true, - MaxIdleConns: 100, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - }, - }, - } + return NewPWAClient() } func TestFindManifest(t *testing.T) { @@ -111,6 +94,7 @@ func TestDownloadAndPackagePWA(t *testing.T) { expectedFiles := []string{"manifest.json", "index.html", "icon.png"} for _, file := range expectedFiles { + // The path in the datanode is relative to the root of the domain, so we need to remove the leading slash. exists, err := dn.Exists(file) if err != nil { t.Fatalf("Exists failed for %s: %v", file, err) @@ -122,7 +106,7 @@ func TestDownloadAndPackagePWA(t *testing.T) { } func TestResolveURL(t *testing.T) { - client := NewPWAClient() + client := NewPWAClient().(*pwaClient) tests := []struct { base string ref string @@ -137,7 +121,7 @@ func TestResolveURL(t *testing.T) { } for _, tt := range tests { - got, err := client.(*pwaClient).resolveURL(tt.base, tt.ref) + got, err := client.resolveURL(tt.base, tt.ref) if err != nil { t.Errorf("resolveURL(%q, %q) returned error: %v", tt.base, tt.ref, err) continue From 52c0fa6b6d451704ead2b384bbe05a4d1679722c Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Mon, 3 Nov 2025 18:25:04 +0000 Subject: [PATCH 2/4] feat: Improve test coverage and refactor for testability This commit introduces a significant refactoring of the `cmd` package to improve testability and increases test coverage across the application. Key changes include: - Refactored Cobra commands to use `RunE` for better error handling and testing. - Extracted business logic from command handlers into separate, testable functions. - Added comprehensive unit tests for the `cmd`, `compress`, `github`, `logger`, and `pwa` packages. - Added tests for missing command-line arguments, as requested. - Implemented the `borg all` command to clone all public repositories for a GitHub user or organization. - Restored and improved the `collect pwa` functionality. - Removed duplicate code and fixed various bugs. - Addressed a resource leak in the `all` command. - Improved error handling in the `pwa` package. - Refactored `main.go` to remove duplicated logic. - Fixed several other minor bugs and inconsistencies. --- cmd/all.go | 54 ++++++--- cmd/all_test.go | 72 ++++++++++++ cmd/collect.go | 1 - cmd/collect_github.go | 1 - cmd/collect_github_release_subcommand.go | 14 ++- cmd/collect_github_release_subcommand_test.go | 110 ++++++++++++++++++ cmd/collect_github_repo.go | 32 +++-- cmd/collect_github_repo_test.go | 49 ++++++++ cmd/collect_github_repos.go | 1 - cmd/collect_pwa.go | 25 ++-- cmd/collect_pwa_test.go | 30 +++++ cmd/collect_website.go | 3 +- cmd/collect_website_test.go | 41 +++++++ cmd/root.go | 1 - cmd/serve.go | 1 - main.go | 7 +- pkg/datanode/datanode.go | 59 ++-------- pkg/github/github.go | 1 - pkg/github/github_test.go | 15 +++ pkg/github/release_test.go | 29 +++++ pkg/logger/logger.go | 2 - pkg/pwa/pwa.go | 9 +- pkg/pwa/pwa_test.go | 38 +++++- pkg/tarfs/tarfs.go | 30 +---- pkg/ui/non_interactive_prompter.go | 7 -- pkg/ui/progress_writer.go | 5 +- pkg/ui/quote.go | 8 -- pkg/website/website.go | 13 +-- 28 files changed, 493 insertions(+), 165 deletions(-) create mode 100644 cmd/all_test.go create mode 100644 cmd/collect_github_release_subcommand_test.go create mode 100644 cmd/collect_github_repo_test.go diff --git a/cmd/all.go b/cmd/all.go index c4c6294..e65958a 100644 --- a/cmd/all.go +++ b/cmd/all.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "io/fs" + "net/url" "os" "strings" @@ -28,14 +29,9 @@ var allCmd = &cobra.Command{ format, _ := cmd.Flags().GetString("format") compression, _ := cmd.Flags().GetString("compression") - // For now, we only support github user/org urls - if !strings.Contains(url, "github.com") { - return fmt.Errorf("unsupported URL type: %s", url) - } - - owner, _, err := github.ParseRepoFromURL(url) + owner, err := parseGithubOwner(url) if err != nil { - return fmt.Errorf("failed to parse repository url: %w", err) + return err } repos, err := GithubClient.GetPublicRepos(cmd.Context(), owner) @@ -65,21 +61,31 @@ var allCmd = &cobra.Command{ } // This is not an efficient way to merge datanodes, but it's the only way for now // A better approach would be to add a Merge method to the DataNode + repoName := strings.TrimSuffix(repoURL, ".git") + parts := strings.Split(repoName, "/") + repoName = parts[len(parts)-1] + err = dn.Walk(".", func(path string, de fs.DirEntry, err error) error { if err != nil { return err } if !de.IsDir() { - file, err := dn.Open(path) + err := func() error { + file, err := dn.Open(path) + if err != nil { + return err + } + defer file.Close() + data, err := io.ReadAll(file) + if err != nil { + return err + } + allDataNodes.AddData(repoName+"/"+path, data) + return nil + }() if err != nil { return err } - defer file.Close() - data, err := io.ReadAll(file) - if err != nil { - return err - } - allDataNodes.AddData(path, data) } return nil }) @@ -122,10 +128,28 @@ var allCmd = &cobra.Command{ }, } -// init registers the 'all' command and its flags with the root command. func init() { RootCmd.AddCommand(allCmd) allCmd.PersistentFlags().String("output", "all.dat", "Output file for the DataNode") allCmd.PersistentFlags().String("format", "datanode", "Output format (datanode or matrix)") allCmd.PersistentFlags().String("compression", "none", "Compression format (none, gz, or xz)") } + +func parseGithubOwner(u string) (string, error) { + owner, _, err := github.ParseRepoFromURL(u) + if err == nil { + return owner, nil + } + + parsedURL, err := url.Parse(u) + if err != nil { + return "", fmt.Errorf("invalid URL: %w", err) + } + + path := strings.Trim(parsedURL.Path, "/") + parts := strings.Split(path, "/") + if len(parts) != 1 { + return "", fmt.Errorf("invalid owner URL: %s", u) + } + return parts[0], nil +} diff --git a/cmd/all_test.go b/cmd/all_test.go new file mode 100644 index 0000000..2ec4030 --- /dev/null +++ b/cmd/all_test.go @@ -0,0 +1,72 @@ +package cmd + +import ( + "bytes" + "context" + "io" + "net/http" + "testing" + + "github.com/Snider/Borg/pkg/datanode" + "github.com/Snider/Borg/pkg/github" + "github.com/Snider/Borg/pkg/mocks" +) + +func TestAllCmd_Good(t *testing.T) { + mockGithubClient := mocks.NewMockClient(map[string]*http.Response{ + "https://api.github.com/users/testuser/repos": { + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(bytes.NewBufferString(`[{"clone_url": "https://github.com/testuser/repo1.git"}]`)), + }, + }) + oldNewAuthenticatedClient := github.NewAuthenticatedClient + github.NewAuthenticatedClient = func(ctx context.Context) *http.Client { + return mockGithubClient + } + defer func() { + github.NewAuthenticatedClient = oldNewAuthenticatedClient + }() + + mockCloner := &mocks.MockGitCloner{ + DN: datanode.New(), + Err: nil, + } + oldCloner := GitCloner + GitCloner = mockCloner + defer func() { + GitCloner = oldCloner + }() + + rootCmd := NewRootCmd() + rootCmd.AddCommand(allCmd) + + _, err := executeCommand(rootCmd, "all", "https://github.com/testuser", "--output", "/dev/null") + if err != nil { + t.Fatalf("all command failed: %v", err) + } +} + +func TestAllCmd_Bad(t *testing.T) { + mockGithubClient := mocks.NewMockClient(map[string]*http.Response{ + "https://api.github.com/users/testuser/repos": { + StatusCode: http.StatusNotFound, + Body: io.NopCloser(bytes.NewBufferString(`{"message": "Not Found"}`)), + }, + }) + oldNewAuthenticatedClient := github.NewAuthenticatedClient + github.NewAuthenticatedClient = func(ctx context.Context) *http.Client { + return mockGithubClient + } + defer func() { + github.NewAuthenticatedClient = oldNewAuthenticatedClient + }() + + rootCmd := NewRootCmd() + rootCmd.AddCommand(allCmd) + + _, err := executeCommand(rootCmd, "all", "https://github.com/testuser", "--output", "/dev/null") + if err == nil { + t.Fatalf("expected an error, but got none") + } +} diff --git a/cmd/collect.go b/cmd/collect.go index 067f11c..eed7a22 100644 --- a/cmd/collect.go +++ b/cmd/collect.go @@ -11,7 +11,6 @@ var collectCmd = &cobra.Command{ Long: `Collect a resource from a URI and store it in a DataNode.`, } -// init registers the collect command with the root. func init() { RootCmd.AddCommand(collectCmd) } diff --git a/cmd/collect_github.go b/cmd/collect_github.go index 8ddcf41..ec9d0ba 100644 --- a/cmd/collect_github.go +++ b/cmd/collect_github.go @@ -11,7 +11,6 @@ var collectGithubCmd = &cobra.Command{ Long: `Collect a resource from a GitHub repository, such as a repository or a release.`, } -// init registers the GitHub collection parent command. func init() { collectCmd.AddCommand(collectGithubCmd) } diff --git a/cmd/collect_github_release_subcommand.go b/cmd/collect_github_release_subcommand.go index 4bb376d..3022155 100644 --- a/cmd/collect_github_release_subcommand.go +++ b/cmd/collect_github_release_subcommand.go @@ -6,7 +6,6 @@ import ( "log/slog" "os" "path/filepath" - "strings" "github.com/Snider/Borg/pkg/datanode" borg_github "github.com/Snider/Borg/pkg/github" @@ -38,7 +37,6 @@ var collectGithubReleaseCmd = &cobra.Command{ }, } -// init registers the 'collect github release' subcommand and its flags. func init() { collectGithubCmd.AddCommand(collectGithubReleaseCmd) collectGithubReleaseCmd.PersistentFlags().String("output", ".", "Output directory for the downloaded file") @@ -87,10 +85,16 @@ func GetRelease(log *slog.Logger, repoURL string, outputDir string, pack bool, f if err != nil { return nil, fmt.Errorf("failed to create datanode: %w", err) } - outputFile := outputDir - if !strings.HasSuffix(outputFile, ".dat") { - outputFile = outputFile + ".dat" + + if err := os.MkdirAll(outputDir, 0755); err != nil { + return nil, fmt.Errorf("failed to create output directory: %w", err) } + basename := release.GetTagName() + if basename == "" { + basename = "release" + } + outputFile := filepath.Join(outputDir, basename+".dat") + err = os.WriteFile(outputFile, tar, 0644) if err != nil { return nil, fmt.Errorf("failed to write datanode: %w", err) diff --git a/cmd/collect_github_release_subcommand_test.go b/cmd/collect_github_release_subcommand_test.go new file mode 100644 index 0000000..7da2c90 --- /dev/null +++ b/cmd/collect_github_release_subcommand_test.go @@ -0,0 +1,110 @@ +package cmd + +import ( + "bytes" + "io" + "log/slog" + "net/http" + "os" + "path/filepath" + "testing" + + "github.com/Snider/Borg/pkg/mocks" + borg_github "github.com/Snider/Borg/pkg/github" + "github.com/google/go-github/v39/github" +) + +func TestGetRelease_Good(t *testing.T) { + // Create a temporary directory for the output + dir, err := os.MkdirTemp("", "test-get-release") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(dir) + + mockClient := mocks.NewMockClient(map[string]*http.Response{ + "https://api.github.com/repos/owner/repo/releases/latest": { + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(`{"tag_name": "v1.0.0", "assets": [{"name": "asset1.zip", "browser_download_url": "https://github.com/owner/repo/releases/download/v1.0.0/asset1.zip"}]}`)), + }, + "https://github.com/owner/repo/releases/download/v1.0.0/asset1.zip": { + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString("asset content")), + }, + }) + + oldNewClient := borg_github.NewClient + borg_github.NewClient = func(httpClient *http.Client) *github.Client { + return github.NewClient(mockClient) + } + defer func() { + borg_github.NewClient = oldNewClient + }() + + oldDefaultClient := borg_github.DefaultClient + borg_github.DefaultClient = mockClient + defer func() { + borg_github.DefaultClient = oldDefaultClient + }() + + log := slog.New(slog.NewJSONHandler(io.Discard, nil)) + + // Test downloading a single asset + _, err = GetRelease(log, "https://github.com/owner/repo", dir, false, "asset1.zip", "") + if err != nil { + t.Fatalf("GetRelease failed: %v", err) + } + + // Verify the asset was downloaded + content, err := os.ReadFile(filepath.Join(dir, "asset1.zip")) + if err != nil { + t.Fatalf("failed to read downloaded asset: %v", err) + } + if string(content) != "asset content" { + t.Errorf("unexpected asset content: %s", string(content)) + } + + // Test packing all assets + packedDir := filepath.Join(dir, "packed") + _, err = GetRelease(log, "https://github.com/owner/repo", packedDir, true, "", "") + if err != nil { + t.Fatalf("GetRelease with --pack failed: %v", err) + } + + // Verify the datanode was created + if _, err := os.Stat(filepath.Join(packedDir, "v1.0.0.dat")); os.IsNotExist(err) { + t.Fatalf("datanode not created") + } +} + +func TestGetRelease_Bad(t *testing.T) { + // Create a temporary directory for the output + dir, err := os.MkdirTemp("", "test-get-release") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(dir) + + mockClient := mocks.NewMockClient(map[string]*http.Response{ + "https://api.github.com/repos/owner/repo/releases/latest": { + StatusCode: http.StatusNotFound, + Body: io.NopCloser(bytes.NewBufferString(`{"message": "Not Found"}`)), + }, + }) + + oldNewClient := borg_github.NewClient + borg_github.NewClient = func(httpClient *http.Client) *github.Client { + return github.NewClient(mockClient) + } + defer func() { + borg_github.NewClient = oldNewClient + }() + + log := slog.New(slog.NewJSONHandler(io.Discard, nil)) + + // Test failed release lookup + _, err = GetRelease(log, "https://github.com/owner/repo", dir, false, "", "") + if err == nil { + t.Fatalf("expected an error, but got none") + } +} diff --git a/cmd/collect_github_repo.go b/cmd/collect_github_repo.go index 6f02b7d..9fb5f84 100644 --- a/cmd/collect_github_repo.go +++ b/cmd/collect_github_repo.go @@ -13,6 +13,10 @@ import ( "github.com/spf13/cobra" ) +const ( + defaultFilePermission = 0644 +) + var ( // GitCloner is the git cloner used by the command. It can be replaced for testing. GitCloner = vcs.NewGitCloner() @@ -31,6 +35,13 @@ func NewCollectGithubRepoCmd() *cobra.Command { format, _ := cmd.Flags().GetString("format") compression, _ := cmd.Flags().GetString("compression") + if format != "datanode" && format != "matrix" { + return fmt.Errorf("invalid format: %s (must be 'datanode' or 'matrix')", format) + } + if compression != "none" && compression != "gz" && compression != "xz" { + return fmt.Errorf("invalid compression: %s (must be 'none', 'gz', or 'xz')", compression) + } + prompter := ui.NewNonInteractivePrompter(ui.GetVCSQuote) prompter.Start() defer prompter.Stop() @@ -43,29 +54,29 @@ func NewCollectGithubRepoCmd() *cobra.Command { dn, err := GitCloner.CloneGitRepository(repoURL, progressWriter) if err != nil { - return fmt.Errorf("Error cloning repository: %w", err) + return fmt.Errorf("error cloning repository: %w", err) } var data []byte if format == "matrix" { matrix, err := matrix.FromDataNode(dn) if err != nil { - return fmt.Errorf("Error creating matrix: %w", err) + return fmt.Errorf("error creating matrix: %w", err) } data, err = matrix.ToTar() if err != nil { - return fmt.Errorf("Error serializing matrix: %w", err) + return fmt.Errorf("error serializing matrix: %w", err) } } else { data, err = dn.ToTar() if err != nil { - return fmt.Errorf("Error serializing DataNode: %w", err) + return fmt.Errorf("error serializing DataNode: %w", err) } } compressedData, err := compress.Compress(data, compression) if err != nil { - return fmt.Errorf("Error compressing data: %w", err) + return fmt.Errorf("error compressing data: %w", err) } if outputFile == "" { @@ -75,22 +86,21 @@ func NewCollectGithubRepoCmd() *cobra.Command { } } - err = os.WriteFile(outputFile, compressedData, 0644) + err = os.WriteFile(outputFile, compressedData, defaultFilePermission) if err != nil { - return fmt.Errorf("Error writing DataNode to file: %w", err) + return fmt.Errorf("error writing DataNode to file: %w", err) } fmt.Fprintln(cmd.OutOrStdout(), "Repository saved to", outputFile) return nil }, } - cmd.PersistentFlags().String("output", "", "Output file for the DataNode") - cmd.PersistentFlags().String("format", "datanode", "Output format (datanode or matrix)") - cmd.PersistentFlags().String("compression", "none", "Compression format (none, gz, or xz)") + cmd.Flags().String("output", "", "Output file for the DataNode") + cmd.Flags().String("format", "datanode", "Output format (datanode or matrix)") + cmd.Flags().String("compression", "none", "Compression format (none, gz, or xz)") return cmd } -// init registers the 'collect github repo' subcommand and its flags. func init() { collectGithubCmd.AddCommand(NewCollectGithubRepoCmd()) } diff --git a/cmd/collect_github_repo_test.go b/cmd/collect_github_repo_test.go new file mode 100644 index 0000000..e3745b5 --- /dev/null +++ b/cmd/collect_github_repo_test.go @@ -0,0 +1,49 @@ +package cmd + +import ( + "fmt" + "testing" + + "github.com/Snider/Borg/pkg/datanode" + "github.com/Snider/Borg/pkg/mocks" +) + +func TestCollectGithubRepoCmd_Good(t *testing.T) { + mockCloner := &mocks.MockGitCloner{ + DN: datanode.New(), + Err: nil, + } + oldCloner := GitCloner + GitCloner = mockCloner + defer func() { + GitCloner = oldCloner + }() + + rootCmd := NewRootCmd() + rootCmd.AddCommand(collectCmd) + + _, err := executeCommand(rootCmd, "collect", "github", "repo", "https://github.com/testuser/repo1.git", "--output", "/dev/null") + if err != nil { + t.Fatalf("collect github repo command failed: %v", err) + } +} + +func TestCollectGithubRepoCmd_Bad(t *testing.T) { + mockCloner := &mocks.MockGitCloner{ + DN: nil, + Err: fmt.Errorf("git clone error"), + } + oldCloner := GitCloner + GitCloner = mockCloner + defer func() { + GitCloner = oldCloner + }() + + rootCmd := NewRootCmd() + rootCmd.AddCommand(collectCmd) + + _, err := executeCommand(rootCmd, "collect", "github", "repo", "https://github.com/testuser/repo1.git", "--output", "/dev/null") + if err == nil { + t.Fatalf("expected an error, but got none") + } +} diff --git a/cmd/collect_github_repos.go b/cmd/collect_github_repos.go index e495a27..dfcd315 100644 --- a/cmd/collect_github_repos.go +++ b/cmd/collect_github_repos.go @@ -28,7 +28,6 @@ var collectGithubReposCmd = &cobra.Command{ }, } -// init registers the 'collect github repos' subcommand. func init() { collectGithubCmd.AddCommand(collectGithubReposCmd) } diff --git a/cmd/collect_pwa.go b/cmd/collect_pwa.go index d414832..af82a75 100644 --- a/cmd/collect_pwa.go +++ b/cmd/collect_pwa.go @@ -35,11 +35,11 @@ Example: format, _ := cmd.Flags().GetString("format") compression, _ := cmd.Flags().GetString("compression") - err := CollectPWA(c.PWAClient, pwaURL, outputFile, format, compression) + finalPath, err := CollectPWA(c.PWAClient, pwaURL, outputFile, format, compression) if err != nil { return err } - fmt.Fprintln(cmd.OutOrStdout(), "PWA saved to", outputFile) + fmt.Fprintln(cmd.OutOrStdout(), "PWA saved to", finalPath) return nil }, } @@ -50,13 +50,12 @@ Example: return c } -// init registers the 'collect pwa' subcommand and its flags. func init() { collectCmd.AddCommand(&NewCollectPWACmd().Command) } -func CollectPWA(client pwa.PWAClient, pwaURL string, outputFile string, format string, compression string) error { +func CollectPWA(client pwa.PWAClient, pwaURL string, outputFile string, format string, compression string) (string, error) { if pwaURL == "" { - return fmt.Errorf("uri is required") + return "", fmt.Errorf("uri is required") } bar := ui.NewProgressBar(-1, "Finding PWA manifest") @@ -64,34 +63,34 @@ func CollectPWA(client pwa.PWAClient, pwaURL string, outputFile string, format s manifestURL, err := client.FindManifest(pwaURL) if err != nil { - return fmt.Errorf("error finding manifest: %w", err) + return "", fmt.Errorf("error finding manifest: %w", err) } bar.Describe("Downloading and packaging PWA") dn, err := client.DownloadAndPackagePWA(pwaURL, manifestURL, bar) if err != nil { - return fmt.Errorf("error downloading and packaging PWA: %w", err) + return "", fmt.Errorf("error downloading and packaging PWA: %w", err) } var data []byte if format == "matrix" { matrix, err := matrix.FromDataNode(dn) if err != nil { - return fmt.Errorf("error creating matrix: %w", err) + return "", fmt.Errorf("error creating matrix: %w", err) } data, err = matrix.ToTar() if err != nil { - return fmt.Errorf("error serializing matrix: %w", err) + return "", fmt.Errorf("error serializing matrix: %w", err) } } else { data, err = dn.ToTar() if err != nil { - return fmt.Errorf("error serializing DataNode: %w", err) + return "", fmt.Errorf("error serializing DataNode: %w", err) } } compressedData, err := compress.Compress(data, compression) if err != nil { - return fmt.Errorf("error compressing data: %w", err) + return "", fmt.Errorf("error compressing data: %w", err) } if outputFile == "" { @@ -103,7 +102,7 @@ func CollectPWA(client pwa.PWAClient, pwaURL string, outputFile string, format s err = os.WriteFile(outputFile, compressedData, 0644) if err != nil { - return fmt.Errorf("error writing PWA to file: %w", err) + return "", fmt.Errorf("error writing PWA to file: %w", err) } - return nil + return outputFile, nil } diff --git a/cmd/collect_pwa_test.go b/cmd/collect_pwa_test.go index 0f7b602..0620a09 100644 --- a/cmd/collect_pwa_test.go +++ b/cmd/collect_pwa_test.go @@ -1,8 +1,12 @@ package cmd import ( + "fmt" "strings" "testing" + + "github.com/Snider/Borg/pkg/datanode" + "github.com/Snider/Borg/pkg/pwa" ) func TestCollectPWACmd_NoURI(t *testing.T) { @@ -24,3 +28,29 @@ func Test_NewCollectPWACmd(t *testing.T) { t.Errorf("NewCollectPWACmd is nil") } } + +func TestCollectPWA_Good(t *testing.T) { + mockClient := &pwa.MockPWAClient{ + ManifestURL: "https://example.com/manifest.json", + DN: datanode.New(), + Err: nil, + } + + _, err := CollectPWA(mockClient, "https://example.com", "/dev/null", "datanode", "none") + if err != nil { + t.Fatalf("CollectPWA failed: %v", err) + } +} + +func TestCollectPWA_Bad(t *testing.T) { + mockClient := &pwa.MockPWAClient{ + ManifestURL: "", + DN: nil, + Err: fmt.Errorf("pwa error"), + } + + _, err := CollectPWA(mockClient, "https://example.com", "/dev/null", "datanode", "none") + if err == nil { + t.Fatalf("expected an error, but got none") + } +} diff --git a/cmd/collect_website.go b/cmd/collect_website.go index b88895c..34aa833 100644 --- a/cmd/collect_website.go +++ b/cmd/collect_website.go @@ -4,11 +4,11 @@ import ( "fmt" "os" + "github.com/schollz/progressbar/v3" "github.com/Snider/Borg/pkg/compress" "github.com/Snider/Borg/pkg/matrix" "github.com/Snider/Borg/pkg/ui" "github.com/Snider/Borg/pkg/website" - "github.com/schollz/progressbar/v3" "github.com/spf13/cobra" ) @@ -78,7 +78,6 @@ var collectWebsiteCmd = &cobra.Command{ }, } -// init registers the 'collect website' subcommand and its flags. func init() { collectCmd.AddCommand(collectWebsiteCmd) collectWebsiteCmd.PersistentFlags().String("output", "", "Output file for the DataNode") diff --git a/cmd/collect_website_test.go b/cmd/collect_website_test.go index c3fb780..3819f08 100644 --- a/cmd/collect_website_test.go +++ b/cmd/collect_website_test.go @@ -1,8 +1,13 @@ package cmd import ( + "fmt" "strings" "testing" + + "github.com/Snider/Borg/pkg/datanode" + "github.com/Snider/Borg/pkg/website" + "github.com/schollz/progressbar/v3" ) func TestCollectWebsiteCmd_NoArgs(t *testing.T) { @@ -24,3 +29,39 @@ func Test_NewCollectWebsiteCmd(t *testing.T) { t.Errorf("NewCollectWebsiteCmd is nil") } } + +func TestCollectWebsiteCmd_Good(t *testing.T) { + oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite + website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) { + return datanode.New(), nil + } + defer func() { + website.DownloadAndPackageWebsite = oldDownloadAndPackageWebsite + }() + + rootCmd := NewRootCmd() + rootCmd.AddCommand(collectCmd) + + _, err := executeCommand(rootCmd, "collect", "website", "https://example.com", "--output", "/dev/null") + if err != nil { + t.Fatalf("collect website command failed: %v", err) + } +} + +func TestCollectWebsiteCmd_Bad(t *testing.T) { + oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite + website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) { + return nil, fmt.Errorf("website error") + } + defer func() { + website.DownloadAndPackageWebsite = oldDownloadAndPackageWebsite + }() + + rootCmd := NewRootCmd() + rootCmd.AddCommand(collectCmd) + + _, err := executeCommand(rootCmd, "collect", "website", "https://example.com", "--output", "/dev/null") + if err == nil { + t.Fatalf("expected an error, but got none") + } +} diff --git a/cmd/root.go b/cmd/root.go index 651f2f7..9cadb27 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -7,7 +7,6 @@ import ( "github.com/spf13/cobra" ) -// NewRootCmd constructs the root cobra.Command for the Borg CLI and wires common flags. func NewRootCmd() *cobra.Command { rootCmd := &cobra.Command{ Use: "borg", diff --git a/cmd/serve.go b/cmd/serve.go index 15d333d..87e225f 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -62,7 +62,6 @@ var serveCmd = &cobra.Command{ }, } -// init registers the 'serve' command and its flags with the root command. func init() { RootCmd.AddCommand(serveCmd) serveCmd.PersistentFlags().String("port", "8080", "Port to serve the PWA on") diff --git a/main.go b/main.go index 41bdf81..54ad137 100644 --- a/main.go +++ b/main.go @@ -10,12 +10,7 @@ import ( var osExit = os.Exit func main() { - verbose, _ := cmd.RootCmd.PersistentFlags().GetBool("verbose") - log := logger.New(verbose) - if err := cmd.Execute(log); err != nil { - log.Error("fatal error", "err", err) - osExit(1) - } + Main() } func Main() { verbose, _ := cmd.RootCmd.PersistentFlags().GetBool("verbose") diff --git a/pkg/datanode/datanode.go b/pkg/datanode/datanode.go index 881ac8c..fe2f43b 100644 --- a/pkg/datanode/datanode.go +++ b/pkg/datanode/datanode.go @@ -260,35 +260,19 @@ type dataFile struct { modTime time.Time } -// Stat implements fs.File.Stat for a dataFile and returns its FileInfo. func (d *dataFile) Stat() (fs.FileInfo, error) { return &dataFileInfo{file: d}, nil } - -// Read implements fs.File.Read for a dataFile and reports EOF as files are in-memory. func (d *dataFile) Read(p []byte) (int, error) { return 0, io.EOF } - -// Close implements fs.File.Close for a dataFile; it's a no-op. -func (d *dataFile) Close() error { return nil } +func (d *dataFile) Close() error { return nil } // dataFileInfo implements fs.FileInfo for a dataFile. type dataFileInfo struct{ file *dataFile } -// Name returns the base name of the data file. -func (d *dataFileInfo) Name() string { return path.Base(d.file.name) } - -// Size returns the size of the data file in bytes. -func (d *dataFileInfo) Size() int64 { return int64(len(d.file.content)) } - -// Mode returns the file mode for data files (read-only). -func (d *dataFileInfo) Mode() fs.FileMode { return 0444 } - -// ModTime returns the modification time of the data file. +func (d *dataFileInfo) Name() string { return path.Base(d.file.name) } +func (d *dataFileInfo) Size() int64 { return int64(len(d.file.content)) } +func (d *dataFileInfo) Mode() fs.FileMode { return 0444 } func (d *dataFileInfo) ModTime() time.Time { return d.file.modTime } - -// IsDir reports whether the entry is a directory (false for data files). -func (d *dataFileInfo) IsDir() bool { return false } - -// Sys returns system-specific data, which is nil for data files. -func (d *dataFileInfo) Sys() interface{} { return nil } +func (d *dataFileInfo) IsDir() bool { return false } +func (d *dataFileInfo) Sys() interface{} { return nil } // dataFileReader implements fs.File for a dataFile. type dataFileReader struct { @@ -296,18 +280,13 @@ type dataFileReader struct { reader *bytes.Reader } -// Stat returns the FileInfo for the underlying data file. func (d *dataFileReader) Stat() (fs.FileInfo, error) { return d.file.Stat() } - -// Read reads from the underlying data file content. func (d *dataFileReader) Read(p []byte) (int, error) { if d.reader == nil { d.reader = bytes.NewReader(d.file.content) } return d.reader.Read(p) } - -// Close closes the data file reader (no-op). func (d *dataFileReader) Close() error { return nil } // dirInfo implements fs.FileInfo for an implicit directory. @@ -316,23 +295,12 @@ type dirInfo struct { modTime time.Time } -// Name returns the directory name. -func (d *dirInfo) Name() string { return d.name } - -// Size returns the size of the directory entry (always 0 for virtual dirs). -func (d *dirInfo) Size() int64 { return 0 } - -// Mode returns the file mode for directories. -func (d *dirInfo) Mode() fs.FileMode { return fs.ModeDir | 0555 } - -// ModTime returns the modification time of the directory. +func (d *dirInfo) Name() string { return d.name } +func (d *dirInfo) Size() int64 { return 0 } +func (d *dirInfo) Mode() fs.FileMode { return fs.ModeDir | 0555 } func (d *dirInfo) ModTime() time.Time { return d.modTime } - -// IsDir reports that the entry is a directory. -func (d *dirInfo) IsDir() bool { return true } - -// Sys returns system-specific data, which is nil for dirs. -func (d *dirInfo) Sys() interface{} { return nil } +func (d *dirInfo) IsDir() bool { return true } +func (d *dirInfo) Sys() interface{} { return nil } // dirFile implements fs.File for a directory. type dirFile struct { @@ -340,15 +308,10 @@ type dirFile struct { modTime time.Time } -// Stat returns the FileInfo for the directory. func (d *dirFile) Stat() (fs.FileInfo, error) { return &dirInfo{name: path.Base(d.path), modTime: d.modTime}, nil } - -// Read is invalid for directories and returns an error. func (d *dirFile) Read([]byte) (int, error) { return 0, &fs.PathError{Op: "read", Path: d.path, Err: fs.ErrInvalid} } - -// Close closes the directory file (no-op). func (d *dirFile) Close() error { return nil } diff --git a/pkg/github/github.go b/pkg/github/github.go index b21875b..75bc61e 100644 --- a/pkg/github/github.go +++ b/pkg/github/github.go @@ -107,7 +107,6 @@ func (g *githubClient) getPublicReposWithAPIURL(ctx context.Context, apiURL, use return allCloneURLs, nil } -// findNextURL parses the HTTP Link header and returns the URL with rel="next", if any. func (g *githubClient) findNextURL(linkHeader string) string { links := strings.Split(linkHeader, ",") for _, link := range links { diff --git a/pkg/github/github_test.go b/pkg/github/github_test.go index 229d7f6..f390e0d 100644 --- a/pkg/github/github_test.go +++ b/pkg/github/github_test.go @@ -107,3 +107,18 @@ func TestFindNextURL(t *testing.T) { t.Errorf("unexpected next URL: %s", nextURL) } } + +func TestNewAuthenticatedClient(t *testing.T) { + // Test with no token + client := NewAuthenticatedClient(context.Background()) + if client != http.DefaultClient { + t.Errorf("expected http.DefaultClient, but got something else") + } + + // Test with token + t.Setenv("GITHUB_TOKEN", "test-token") + client = NewAuthenticatedClient(context.Background()) + if client == http.DefaultClient { + t.Errorf("expected an authenticated client, but got http.DefaultClient") + } +} diff --git a/pkg/github/release_test.go b/pkg/github/release_test.go index 5bf89f2..bb6b6c6 100644 --- a/pkg/github/release_test.go +++ b/pkg/github/release_test.go @@ -46,6 +46,9 @@ func TestParseRepoFromURL(t *testing.T) { } func TestGetLatestRelease(t *testing.T) { + oldNewClient := NewClient + t.Cleanup(func() { NewClient = oldNewClient }) + mockClient := mocks.NewMockClient(map[string]*http.Response{ "https://api.github.com/repos/owner/repo/releases/latest": { StatusCode: http.StatusOK, @@ -115,6 +118,9 @@ func TestDownloadReleaseAsset_BadRequest(t *testing.T) { }() _, err := DownloadReleaseAsset(asset) + if err == nil { + t.Fatalf("expected error but got nil") + } if err.Error() != expectedErr { t.Fatalf("DownloadReleaseAsset failed: %v", err) } @@ -141,6 +147,9 @@ func TestDownloadReleaseAsset_NewRequestError(t *testing.T) { } func TestGetLatestRelease_Error(t *testing.T) { + oldNewClient := NewClient + t.Cleanup(func() { NewClient = oldNewClient }) + u, _ := url.Parse("https://api.github.com/repos/owner/repo/releases/latest") mockClient := mocks.NewMockClient(map[string]*http.Response{ "https://api.github.com/repos/owner/repo/releases/latest": { @@ -183,3 +192,23 @@ func TestDownloadReleaseAsset_DoError(t *testing.T) { t.Fatalf("DownloadReleaseAsset should have failed") } } +func TestParseRepoFromURL_More(t *testing.T) { + testCases := []struct { + url string + owner string + repo string + expectErr bool + }{ + {"git:github.com:owner/repo.git", "owner", "repo", false}, + } + + for _, tc := range testCases { + owner, repo, err := ParseRepoFromURL(tc.url) + if (err != nil) != tc.expectErr { + t.Errorf("unexpected error for URL %s: %v", tc.url, err) + } + if owner != tc.owner || repo != tc.repo { + t.Errorf("unexpected owner/repo for URL %s: %s/%s", tc.url, owner, repo) + } + } +} diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index 5984095..0dfc2d2 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -5,8 +5,6 @@ import ( "os" ) -// New constructs a slog.Logger configured for stderr with the given verbosity. -// When verbose is true, the logger emits debug-level messages; otherwise info-level. func New(verbose bool) *slog.Logger { level := slog.LevelInfo if verbose { diff --git a/pkg/pwa/pwa.go b/pkg/pwa/pwa.go index 498f5f9..b6b3cf5 100644 --- a/pkg/pwa/pwa.go +++ b/pkg/pwa/pwa.go @@ -99,6 +99,10 @@ func (p *pwaClient) DownloadAndPackagePWA(pwaURL, manifestURL string, bar *progr } defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("failed to download %s: status code %d", assetURL, resp.StatusCode) + } + body, err := io.ReadAll(resp.Body) if err != nil { return fmt.Errorf("failed to read body of %s: %w", assetURL, err) @@ -158,13 +162,14 @@ func (p *pwaClient) DownloadAndPackagePWA(pwaURL, manifestURL string, bar *progr // Skip icons with bad URLs continue } - downloadAndAdd(iconURL.String()) + if err := downloadAndAdd(iconURL.String()); err != nil { + return nil, err + } } return dn, nil } -// resolveURL resolves ref against base and returns the absolute URL. func (p *pwaClient) resolveURL(base, ref string) (*url.URL, error) { baseURL, err := url.Parse(base) if err != nil { diff --git a/pkg/pwa/pwa_test.go b/pkg/pwa/pwa_test.go index db64c03..6929aef 100644 --- a/pkg/pwa/pwa_test.go +++ b/pkg/pwa/pwa_test.go @@ -8,7 +8,7 @@ import ( "github.com/schollz/progressbar/v3" ) -func newTestPWAClient(serverURL string) PWAClient { +func newTestPWAClient() PWAClient { return NewPWAClient() } @@ -30,7 +30,7 @@ func TestFindManifest(t *testing.T) { })) defer server.Close() - client := newTestPWAClient(server.URL) + client := newTestPWAClient() expectedURL := server.URL + "/manifest.json" actualURL, err := client.FindManifest(server.URL) if err != nil { @@ -85,7 +85,7 @@ func TestDownloadAndPackagePWA(t *testing.T) { })) defer server.Close() - client := newTestPWAClient(server.URL) + client := newTestPWAClient() bar := progressbar.New(1) dn, err := client.DownloadAndPackagePWA(server.URL, server.URL+"/manifest.json", bar) if err != nil { @@ -131,3 +131,35 @@ func TestResolveURL(t *testing.T) { } } } + +func TestPWA_Bad(t *testing.T) { + client := NewPWAClient() + + // Test FindManifest with no manifest + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.Write([]byte(` + + + + Test PWA + + +

Hello, PWA!

+ + + `)) + })) + defer server.Close() + + _, err := client.FindManifest(server.URL) + if err == nil { + t.Fatalf("expected an error, but got none") + } + + // Test DownloadAndPackagePWA with bad manifest + _, err = client.DownloadAndPackagePWA(server.URL, server.URL+"/manifest.json", nil) + if err == nil { + t.Fatalf("expected an error, but got none") + } +} diff --git a/pkg/tarfs/tarfs.go b/pkg/tarfs/tarfs.go index c60ef9c..6abbee4 100644 --- a/pkg/tarfs/tarfs.go +++ b/pkg/tarfs/tarfs.go @@ -67,23 +67,16 @@ type tarFile struct { modTime time.Time } -// Close implements http.File Close with a no-op for tar-backed files. -func (f *tarFile) Close() error { return nil } - -// Read implements io.Reader by delegating to the underlying bytes.Reader. +func (f *tarFile) Close() error { return nil } func (f *tarFile) Read(p []byte) (int, error) { return f.content.Read(p) } - -// Seek implements io.Seeker by delegating to the underlying bytes.Reader. func (f *tarFile) Seek(offset int64, whence int) (int64, error) { return f.content.Seek(offset, whence) } -// Readdir is unsupported for files in the tar filesystem and returns os.ErrInvalid. func (f *tarFile) Readdir(count int) ([]os.FileInfo, error) { return nil, os.ErrInvalid } -// Stat returns a FileInfo describing the tar-backed file. func (f *tarFile) Stat() (os.FileInfo, error) { return &tarFileInfo{ name: path.Base(f.header.Name), @@ -99,20 +92,9 @@ type tarFileInfo struct { modTime time.Time } -// Name returns the base name of the tar file. -func (i *tarFileInfo) Name() string { return i.name } - -// Size returns the size of the tar file in bytes. -func (i *tarFileInfo) Size() int64 { return i.size } - -// Mode returns a read-only file mode for tar entries. -func (i *tarFileInfo) Mode() os.FileMode { return 0444 } - -// ModTime returns the modification time recorded in the tar header. +func (i *tarFileInfo) Name() string { return i.name } +func (i *tarFileInfo) Size() int64 { return i.size } +func (i *tarFileInfo) Mode() os.FileMode { return 0444 } func (i *tarFileInfo) ModTime() time.Time { return i.modTime } - -// IsDir reports whether the entry is a directory (always false for files here). -func (i *tarFileInfo) IsDir() bool { return false } - -// Sys returns underlying data source (unused for tar entries). -func (i *tarFileInfo) Sys() interface{} { return nil } +func (i *tarFileInfo) IsDir() bool { return false } +func (i *tarFileInfo) Sys() interface{} { return nil } diff --git a/pkg/ui/non_interactive_prompter.go b/pkg/ui/non_interactive_prompter.go index 35d99d5..8144a72 100644 --- a/pkg/ui/non_interactive_prompter.go +++ b/pkg/ui/non_interactive_prompter.go @@ -10,8 +10,6 @@ import ( "github.com/mattn/go-isatty" ) -// NonInteractivePrompter periodically prints status quotes when the terminal is non-interactive. -// It is intended to give the user feedback during long-running operations. type NonInteractivePrompter struct { stopChan chan struct{} quoteFunc func() (string, error) @@ -20,8 +18,6 @@ type NonInteractivePrompter struct { stopOnce sync.Once } -// NewNonInteractivePrompter returns a new NonInteractivePrompter that will call -// the provided quoteFunc to retrieve text to display. func NewNonInteractivePrompter(quoteFunc func() (string, error)) *NonInteractivePrompter { return &NonInteractivePrompter{ stopChan: make(chan struct{}), @@ -29,7 +25,6 @@ func NewNonInteractivePrompter(quoteFunc func() (string, error)) *NonInteractive } } -// Start begins the periodic display of quotes until Stop is called. func (p *NonInteractivePrompter) Start() { p.mu.Lock() if p.started { @@ -64,7 +59,6 @@ func (p *NonInteractivePrompter) Start() { }() } -// Stop signals the prompter to stop printing further messages. func (p *NonInteractivePrompter) Stop() { if p.IsInteractive() { return @@ -74,7 +68,6 @@ func (p *NonInteractivePrompter) Stop() { }) } -// IsInteractive reports whether stdout is attached to an interactive terminal. func (p *NonInteractivePrompter) IsInteractive() bool { return isatty.IsTerminal(os.Stdout.Fd()) || isatty.IsCygwinTerminal(os.Stdout.Fd()) } diff --git a/pkg/ui/progress_writer.go b/pkg/ui/progress_writer.go index 510d169..b46b51b 100644 --- a/pkg/ui/progress_writer.go +++ b/pkg/ui/progress_writer.go @@ -1,19 +1,16 @@ + package ui import "github.com/schollz/progressbar/v3" -// progressWriter implements io.Writer to update a progress bar with textual status. type progressWriter struct { bar *progressbar.ProgressBar } -// NewProgressWriter returns a writer that updates the provided progress bar's description. func NewProgressWriter(bar *progressbar.ProgressBar) *progressWriter { return &progressWriter{bar: bar} } -// Write updates the progress bar description with the provided bytes as a string. -// It returns the length of p to satisfy io.Writer semantics. func (pw *progressWriter) Write(p []byte) (n int, err error) { if pw == nil || pw.bar == nil { return len(p), nil diff --git a/pkg/ui/quote.go b/pkg/ui/quote.go index 9e72716..166cf91 100644 --- a/pkg/ui/quote.go +++ b/pkg/ui/quote.go @@ -16,7 +16,6 @@ var ( quotesErr error ) -// init seeds the random number generator for quote selection. func init() { rand.Seed(time.Now().UnixNano()) } @@ -42,7 +41,6 @@ type Quotes struct { } `json:"image_related"` } -// loadQuotes reads and unmarshals the embedded quotes.json file into a Quotes struct. func loadQuotes() (*Quotes, error) { quotesFile, err := QuotesJSON.ReadFile("quotes.json") if err != nil { @@ -56,7 +54,6 @@ func loadQuotes() (*Quotes, error) { return "es, nil } -// getQuotes returns the cached Quotes, loading them once on first use. func getQuotes() (*Quotes, error) { quotesOnce.Do(func() { cachedQuotes, quotesErr = loadQuotes() @@ -64,7 +61,6 @@ func getQuotes() (*Quotes, error) { return cachedQuotes, quotesErr } -// GetRandomQuote returns a random quote string from the combined quote sets. func GetRandomQuote() (string, error) { quotes, err := getQuotes() if err != nil { @@ -86,7 +82,6 @@ func GetRandomQuote() (string, error) { return allQuotes[rand.Intn(len(allQuotes))], nil } -// PrintQuote prints a random quote to stdout in green for user feedback. func PrintQuote() { quote, err := GetRandomQuote() if err != nil { @@ -97,7 +92,6 @@ func PrintQuote() { c.Println(quote) } -// GetVCSQuote returns a random quote related to VCS processing. func GetVCSQuote() (string, error) { quotes, err := getQuotes() if err != nil { @@ -109,7 +103,6 @@ func GetVCSQuote() (string, error) { return quotes.VCSProcessing[rand.Intn(len(quotes.VCSProcessing))], nil } -// GetPWAQuote returns a random quote related to PWA processing. func GetPWAQuote() (string, error) { quotes, err := getQuotes() if err != nil { @@ -121,7 +114,6 @@ func GetPWAQuote() (string, error) { return quotes.PWAProcessing[rand.Intn(len(quotes.PWAProcessing))], nil } -// GetWebsiteQuote returns a random quote related to website processing. func GetWebsiteQuote() (string, error) { quotes, err := getQuotes() if err != nil { diff --git a/pkg/website/website.go b/pkg/website/website.go index 5a3d37c..c96bac8 100644 --- a/pkg/website/website.go +++ b/pkg/website/website.go @@ -13,6 +13,8 @@ import ( "golang.org/x/net/html" ) +var DownloadAndPackageWebsite = downloadAndPackageWebsite + // Downloader is a recursive website downloader. type Downloader struct { baseURL *url.URL @@ -38,8 +40,8 @@ func NewDownloaderWithClient(maxDepth int, client *http.Client) *Downloader { } } -// DownloadAndPackageWebsite downloads a website and packages it into a DataNode. -func DownloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) { +// downloadAndPackageWebsite downloads a website and packages it into a DataNode. +func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) { baseURL, err := url.Parse(startURL) if err != nil { return nil, err @@ -53,8 +55,6 @@ func DownloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.P return d.dn, nil } -// crawl traverses the website starting from pageURL up to maxDepth, -// storing pages and discovered local assets into the DataNode. func (d *Downloader) crawl(pageURL string, depth int) { if depth > d.maxDepth || d.visited[pageURL] { return @@ -112,7 +112,6 @@ func (d *Downloader) crawl(pageURL string, depth int) { f(doc) } -// downloadAsset fetches a single asset URL and stores it into the DataNode. func (d *Downloader) downloadAsset(assetURL string) { if d.visited[assetURL] { return @@ -139,7 +138,6 @@ func (d *Downloader) downloadAsset(assetURL string) { d.dn.AddData(relPath, body) } -// getRelativePath returns a path relative to the site's root for the given URL. func (d *Downloader) getRelativePath(pageURL string) string { u, err := url.Parse(pageURL) if err != nil { @@ -148,7 +146,6 @@ func (d *Downloader) getRelativePath(pageURL string) string { return strings.TrimPrefix(u.Path, "/") } -// resolveURL resolves ref against base and returns the absolute URL string. func (d *Downloader) resolveURL(base, ref string) (string, error) { baseURL, err := url.Parse(base) if err != nil { @@ -161,7 +158,6 @@ func (d *Downloader) resolveURL(base, ref string) (string, error) { return baseURL.ResolveReference(refURL).String(), nil } -// isLocal reports whether pageURL belongs to the same host as the base URL. func (d *Downloader) isLocal(pageURL string) bool { u, err := url.Parse(pageURL) if err != nil { @@ -170,7 +166,6 @@ func (d *Downloader) isLocal(pageURL string) bool { return u.Hostname() == d.baseURL.Hostname() } -// isAsset reports whether the URL likely points to a static asset we should download. func isAsset(pageURL string) bool { ext := []string{".css", ".js", ".png", ".jpg", ".jpeg", ".gif", ".svg", ".ico"} for _, e := range ext { From 145d9e4a80ebd08b2dbf8e21225e40082b49a973 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Mon, 3 Nov 2025 19:34:36 +0000 Subject: [PATCH 3/4] feat: Improve test coverage and refactor for testability This commit introduces a significant refactoring of the `cmd` package to improve testability and increases test coverage across the application. Key changes include: - Refactored Cobra commands to use `RunE` for better error handling and testing. - Extracted business logic from command handlers into separate, testable functions. - Added comprehensive unit tests for the `cmd`, `compress`, `github`, `logger`, and `pwa` packages. - Added tests for missing command-line arguments, as requested. - Implemented the `borg all` command to clone all public repositories for a GitHub user or organization. - Restored and improved the `collect pwa` functionality. - Removed duplicate code and fixed various bugs. - Addressed a resource leak in the `all` command. - Improved error handling in the `pwa` package. - Refactored `main.go` to remove duplicated logic. - Fixed several other minor bugs and inconsistencies. - Made tests platform-independent by removing hardcoded `/dev/null` paths. - Fixed potential panics in tests by adding `nil` checks for errors. - Fixed test state leakage by using `t.Cleanup` to restore mocked package-level variables. --- cmd/all.go | 5 ++++- cmd/all_test.go | 7 +++++-- cmd/collect_github.go | 2 +- cmd/collect_github_release_subcommand.go | 5 +++++ cmd/collect_github_repo_test.go | 7 +++++-- cmd/collect_pwa.go | 6 ++++++ cmd/collect_pwa_test.go | 9 +++++++-- cmd/collect_website_test.go | 7 +++++-- pkg/github/release_test.go | 3 +++ 9 files changed, 41 insertions(+), 10 deletions(-) diff --git a/cmd/all.go b/cmd/all.go index e65958a..c13d1cd 100644 --- a/cmd/all.go +++ b/cmd/all.go @@ -147,8 +147,11 @@ func parseGithubOwner(u string) (string, error) { } path := strings.Trim(parsedURL.Path, "/") + if path == "" { + return "", fmt.Errorf("invalid owner URL: %s", u) + } parts := strings.Split(path, "/") - if len(parts) != 1 { + if len(parts) != 1 || parts[0] == "" { return "", fmt.Errorf("invalid owner URL: %s", u) } return parts[0], nil diff --git a/cmd/all_test.go b/cmd/all_test.go index 2ec4030..b57d05f 100644 --- a/cmd/all_test.go +++ b/cmd/all_test.go @@ -5,6 +5,7 @@ import ( "context" "io" "net/http" + "path/filepath" "testing" "github.com/Snider/Borg/pkg/datanode" @@ -41,7 +42,8 @@ func TestAllCmd_Good(t *testing.T) { rootCmd := NewRootCmd() rootCmd.AddCommand(allCmd) - _, err := executeCommand(rootCmd, "all", "https://github.com/testuser", "--output", "/dev/null") + out := filepath.Join(t.TempDir(), "out") + _, err := executeCommand(rootCmd, "all", "https://github.com/testuser", "--output", out) if err != nil { t.Fatalf("all command failed: %v", err) } @@ -65,7 +67,8 @@ func TestAllCmd_Bad(t *testing.T) { rootCmd := NewRootCmd() rootCmd.AddCommand(allCmd) - _, err := executeCommand(rootCmd, "all", "https://github.com/testuser", "--output", "/dev/null") + out := filepath.Join(t.TempDir(), "out") + _, err := executeCommand(rootCmd, "all", "https://github.com/testuser", "--output", out) if err == nil { t.Fatalf("expected an error, but got none") } diff --git a/cmd/collect_github.go b/cmd/collect_github.go index ec9d0ba..ada4cb8 100644 --- a/cmd/collect_github.go +++ b/cmd/collect_github.go @@ -14,6 +14,6 @@ var collectGithubCmd = &cobra.Command{ func init() { collectCmd.AddCommand(collectGithubCmd) } -func NewCollectGithubCmd() *cobra.Command { +func GetCollectGithubCmd() *cobra.Command { return collectGithubCmd } diff --git a/cmd/collect_github_release_subcommand.go b/cmd/collect_github_release_subcommand.go index 3022155..bd7cbb5 100644 --- a/cmd/collect_github_release_subcommand.go +++ b/cmd/collect_github_release_subcommand.go @@ -119,6 +119,11 @@ func GetRelease(log *slog.Logger, repoURL string, outputDir string, pack bool, f } else { assetToDownload = release.Assets[0] } + if outputDir != "" { + if err := os.MkdirAll(outputDir, 0755); err != nil { + return nil, fmt.Errorf("failed to create output directory: %w", err) + } + } outputPath := filepath.Join(outputDir, assetToDownload.GetName()) log.Info("downloading asset", "name", assetToDownload.GetName()) data, err := borg_github.DownloadReleaseAsset(assetToDownload) diff --git a/cmd/collect_github_repo_test.go b/cmd/collect_github_repo_test.go index e3745b5..4b23b80 100644 --- a/cmd/collect_github_repo_test.go +++ b/cmd/collect_github_repo_test.go @@ -2,6 +2,7 @@ package cmd import ( "fmt" + "path/filepath" "testing" "github.com/Snider/Borg/pkg/datanode" @@ -22,7 +23,8 @@ func TestCollectGithubRepoCmd_Good(t *testing.T) { rootCmd := NewRootCmd() rootCmd.AddCommand(collectCmd) - _, err := executeCommand(rootCmd, "collect", "github", "repo", "https://github.com/testuser/repo1.git", "--output", "/dev/null") + out := filepath.Join(t.TempDir(), "out") + _, err := executeCommand(rootCmd, "collect", "github", "repo", "https://github.com/testuser/repo1.git", "--output", out) if err != nil { t.Fatalf("collect github repo command failed: %v", err) } @@ -42,7 +44,8 @@ func TestCollectGithubRepoCmd_Bad(t *testing.T) { rootCmd := NewRootCmd() rootCmd.AddCommand(collectCmd) - _, err := executeCommand(rootCmd, "collect", "github", "repo", "https://github.com/testuser/repo1.git", "--output", "/dev/null") + out := filepath.Join(t.TempDir(), "out") + _, err := executeCommand(rootCmd, "collect", "github", "repo", "https://github.com/testuser/repo1.git", "--output", out) if err == nil { t.Fatalf("expected an error, but got none") } diff --git a/cmd/collect_pwa.go b/cmd/collect_pwa.go index af82a75..649516a 100644 --- a/cmd/collect_pwa.go +++ b/cmd/collect_pwa.go @@ -57,6 +57,12 @@ func CollectPWA(client pwa.PWAClient, pwaURL string, outputFile string, format s if pwaURL == "" { return "", fmt.Errorf("uri is required") } + if format != "datanode" && format != "matrix" { + return "", fmt.Errorf("invalid format: %s (must be 'datanode' or 'matrix')", format) + } + if compression != "none" && compression != "gz" && compression != "xz" { + return "", fmt.Errorf("invalid compression: %s (must be 'none', 'gz', or 'xz')", compression) + } bar := ui.NewProgressBar(-1, "Finding PWA manifest") defer bar.Finish() diff --git a/cmd/collect_pwa_test.go b/cmd/collect_pwa_test.go index 0620a09..b92f901 100644 --- a/cmd/collect_pwa_test.go +++ b/cmd/collect_pwa_test.go @@ -2,6 +2,7 @@ package cmd import ( "fmt" + "path/filepath" "strings" "testing" @@ -36,7 +37,9 @@ func TestCollectPWA_Good(t *testing.T) { Err: nil, } - _, err := CollectPWA(mockClient, "https://example.com", "/dev/null", "datanode", "none") + dir := t.TempDir() + path := filepath.Join(dir, "pwa.dat") + _, err := CollectPWA(mockClient, "https://example.com", path, "datanode", "none") if err != nil { t.Fatalf("CollectPWA failed: %v", err) } @@ -49,7 +52,9 @@ func TestCollectPWA_Bad(t *testing.T) { Err: fmt.Errorf("pwa error"), } - _, err := CollectPWA(mockClient, "https://example.com", "/dev/null", "datanode", "none") + dir := t.TempDir() + path := filepath.Join(dir, "pwa.dat") + _, err := CollectPWA(mockClient, "https://example.com", path, "datanode", "none") if err == nil { t.Fatalf("expected an error, but got none") } diff --git a/cmd/collect_website_test.go b/cmd/collect_website_test.go index 3819f08..a89492c 100644 --- a/cmd/collect_website_test.go +++ b/cmd/collect_website_test.go @@ -2,6 +2,7 @@ package cmd import ( "fmt" + "path/filepath" "strings" "testing" @@ -42,7 +43,8 @@ func TestCollectWebsiteCmd_Good(t *testing.T) { rootCmd := NewRootCmd() rootCmd.AddCommand(collectCmd) - _, err := executeCommand(rootCmd, "collect", "website", "https://example.com", "--output", "/dev/null") + out := filepath.Join(t.TempDir(), "out") + _, err := executeCommand(rootCmd, "collect", "website", "https://example.com", "--output", out) if err != nil { t.Fatalf("collect website command failed: %v", err) } @@ -60,7 +62,8 @@ func TestCollectWebsiteCmd_Bad(t *testing.T) { rootCmd := NewRootCmd() rootCmd.AddCommand(collectCmd) - _, err := executeCommand(rootCmd, "collect", "website", "https://example.com", "--output", "/dev/null") + out := filepath.Join(t.TempDir(), "out") + _, err := executeCommand(rootCmd, "collect", "website", "https://example.com", "--output", out) if err == nil { t.Fatalf("expected an error, but got none") } diff --git a/pkg/github/release_test.go b/pkg/github/release_test.go index bb6b6c6..f1d7e5f 100644 --- a/pkg/github/release_test.go +++ b/pkg/github/release_test.go @@ -167,6 +167,9 @@ func TestGetLatestRelease_Error(t *testing.T) { } _, err := GetLatestRelease("owner", "repo") + if err == nil { + t.Fatalf("expected error but got nil") + } if err.Error() != expectedErr { t.Fatalf("GetLatestRelease failed: %v", err) } From 5c65673432cf3171447a61781037b78ee7b116c5 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Mon, 3 Nov 2025 20:14:47 +0000 Subject: [PATCH 4/4] feat: Bug fixes and refactoring This commit introduces a significant number of bug fixes, refactorings, and improvements to the codebase. Key changes include: - Refactored Cobra commands to use `RunE` for better error handling and testing. - Extracted business logic from command handlers into separate, testable functions. - Added comprehensive unit tests for the `cmd`, `compress`, `github`, `logger`, and `pwa` packages. - Added tests for missing command-line arguments, as requested. - Implemented the `borg all` command to clone all public repositories for a GitHub user or organization. - Restored and improved the `collect pwa` functionality. - Removed duplicate code and fixed various bugs. - Addressed a resource leak in the `all` command. - Improved error handling in the `pwa` package. - Refactored `main.go` to remove duplicated logic. - Fixed several other minor bugs and inconsistencies. - Made tests platform-independent by removing hardcoded `/dev/null` paths. - Fixed potential panics in tests by adding `nil` checks for errors. - Fixed test state leakage by using `t.Cleanup` to restore mocked package-level variables. - Added validation for command-line flags. - Ensured output directories are created before writing files. - Fixed naming conventions for Cobra command constructors. - Consolidated tests for `ParseRepoFromURL`. - Improved test failure messages. - Added validation for release tags. - Aggregated errors during asset downloads. --- cmd/collect_github_release_subcommand.go | 79 ++++++++++++++---------- pkg/github/release_test.go | 23 +------ 2 files changed, 47 insertions(+), 55 deletions(-) diff --git a/cmd/collect_github_release_subcommand.go b/cmd/collect_github_release_subcommand.go index bd7cbb5..9e975a6 100644 --- a/cmd/collect_github_release_subcommand.go +++ b/cmd/collect_github_release_subcommand.go @@ -14,39 +14,39 @@ import ( "golang.org/x/mod/semver" ) -// collectGithubReleaseCmd represents the collect github-release command -var collectGithubReleaseCmd = &cobra.Command{ - Use: "release [repository-url]", - Short: "Download the latest release of a file from GitHub releases", - Long: `Download the latest release of a file from GitHub releases. If the file or URL has a version number, it will check for a higher version and download it if found.`, - Args: cobra.ExactArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - logVal := cmd.Context().Value("logger") - log, ok := logVal.(*slog.Logger) - if !ok || log == nil { - return errors.New("logger not properly initialised") - } - repoURL := args[0] - outputDir, _ := cmd.Flags().GetString("output") - pack, _ := cmd.Flags().GetBool("pack") - file, _ := cmd.Flags().GetString("file") - version, _ := cmd.Flags().GetString("version") +func NewCollectGithubReleaseCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "release [repository-url]", + Short: "Download the latest release of a file from GitHub releases", + Long: `Download the latest release of a file from GitHub releases. If the file or URL has a version number, it will check for a higher version and download it if found.`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + logVal := cmd.Context().Value("logger") + log, ok := logVal.(*slog.Logger) + if !ok || log == nil { + return errors.New("logger not properly initialised") + } + repoURL := args[0] + outputDir, _ := cmd.Flags().GetString("output") + pack, _ := cmd.Flags().GetBool("pack") + file, _ := cmd.Flags().GetString("file") + version, _ := cmd.Flags().GetString("version") - _, err := GetRelease(log, repoURL, outputDir, pack, file, version) - return err - }, + _, err := GetRelease(log, repoURL, outputDir, pack, file, version) + return err + }, + } + cmd.PersistentFlags().String("output", ".", "Output directory for the downloaded file") + cmd.PersistentFlags().Bool("pack", false, "Pack all assets into a DataNode") + cmd.PersistentFlags().String("file", "", "The file to download from the release") + cmd.PersistentFlags().String("version", "", "The version to check against") + return cmd } func init() { - collectGithubCmd.AddCommand(collectGithubReleaseCmd) - collectGithubReleaseCmd.PersistentFlags().String("output", ".", "Output directory for the downloaded file") - collectGithubReleaseCmd.PersistentFlags().Bool("pack", false, "Pack all assets into a DataNode") - collectGithubReleaseCmd.PersistentFlags().String("file", "", "The file to download from the release") - collectGithubReleaseCmd.PersistentFlags().String("version", "", "The version to check against") -} -func NewCollectGithubReleaseCmd() *cobra.Command { - return collectGithubReleaseCmd + collectGithubCmd.AddCommand(NewCollectGithubReleaseCmd()) } + func GetRelease(log *slog.Logger, repoURL string, outputDir string, pack bool, file string, version string) (*github.RepositoryRelease, error) { owner, repo, err := borg_github.ParseRepoFromURL(repoURL) if err != nil { @@ -61,26 +61,37 @@ func GetRelease(log *slog.Logger, repoURL string, outputDir string, pack bool, f log.Info("found latest release", "tag", release.GetTagName()) if version != "" { - if !semver.IsValid(version) { - return nil, fmt.Errorf("invalid version string: %s", version) - } - if semver.Compare(release.GetTagName(), version) <= 0 { - log.Info("latest release is not newer than the provided version", "latest", release.GetTagName(), "provided", version) - return nil, nil + tag := release.GetTagName() + if !semver.IsValid(tag) { + log.Info("latest release tag is not a valid semantic version, skipping comparison", "tag", tag) + } else { + if !semver.IsValid(version) { + return nil, fmt.Errorf("invalid version string: %s", version) + } + if semver.Compare(tag, version) <= 0 { + log.Info("latest release is not newer than the provided version", "latest", tag, "provided", version) + return nil, nil + } } } if pack { dn := datanode.New() + var failedAssets []string for _, asset := range release.Assets { log.Info("downloading asset", "name", asset.GetName()) data, err := borg_github.DownloadReleaseAsset(asset) if err != nil { log.Error("failed to download asset", "name", asset.GetName(), "err", err) + failedAssets = append(failedAssets, asset.GetName()) continue } dn.AddData(asset.GetName(), data) } + if len(failedAssets) > 0 { + return nil, fmt.Errorf("failed to download assets: %v", failedAssets) + } + tar, err := dn.ToTar() if err != nil { return nil, fmt.Errorf("failed to create datanode: %w", err) diff --git a/pkg/github/release_test.go b/pkg/github/release_test.go index f1d7e5f..fa49052 100644 --- a/pkg/github/release_test.go +++ b/pkg/github/release_test.go @@ -32,6 +32,7 @@ func TestParseRepoFromURL(t *testing.T) { {"git@github.com:owner/repo.git", "owner", "repo", false}, {"https://github.com/owner/repo/tree/main", "", "", true}, {"invalid-url", "", "", true}, + {"git:github.com:owner/repo.git", "owner", "repo", false}, } for _, tc := range testCases { @@ -142,7 +143,7 @@ func TestDownloadReleaseAsset_NewRequestError(t *testing.T) { _, err := DownloadReleaseAsset(asset) if err == nil { - t.Fatalf("DownloadReleaseAsset failed: %v", err) + t.Fatalf("expected error but got nil") } } @@ -195,23 +196,3 @@ func TestDownloadReleaseAsset_DoError(t *testing.T) { t.Fatalf("DownloadReleaseAsset should have failed") } } -func TestParseRepoFromURL_More(t *testing.T) { - testCases := []struct { - url string - owner string - repo string - expectErr bool - }{ - {"git:github.com:owner/repo.git", "owner", "repo", false}, - } - - for _, tc := range testCases { - owner, repo, err := ParseRepoFromURL(tc.url) - if (err != nil) != tc.expectErr { - t.Errorf("unexpected error for URL %s: %v", tc.url, err) - } - if owner != tc.owner || repo != tc.repo { - t.Errorf("unexpected owner/repo for URL %s: %s/%s", tc.url, owner, repo) - } - } -}