diff --git a/.gitignore b/.gitignore index 24fef9a..1bd3f87 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ borg *.cube .task -.idea \ No newline at end of file +*.datanode diff --git a/TDD/collect_commands_test.go b/TDD/collect_commands_test.go new file mode 100644 index 0000000..70d8e9c --- /dev/null +++ b/TDD/collect_commands_test.go @@ -0,0 +1,145 @@ +package tdd_test + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/Snider/Borg/cmd" + "github.com/Snider/Borg/pkg/datanode" + "github.com/schollz/progressbar/v3" +) + +type mockGitCloner struct{} + +func (m *mockGitCloner) CloneGitRepository(repoURL string, progress io.Writer) (*datanode.DataNode, error) { + dn := datanode.New() + dn.AddData("README.md", []byte("Mock README")) + return dn, nil +} + +type mockGithubClient struct{} + +func (m *mockGithubClient) GetPublicRepos(ctx context.Context, userOrOrg string) ([]string, error) { + return []string{"https://github.com/test/repo1.git"}, nil +} + +type mockPWAClient struct{} + +func (m *mockPWAClient) FindManifest(pageURL string) (string, error) { + return "http://test.com/manifest.json", nil +} + +func (m *mockPWAClient) DownloadAndPackagePWA(baseURL string, manifestURL string, bar *progressbar.ProgressBar) (*datanode.DataNode, error) { + dn := datanode.New() + dn.AddData("manifest.json", []byte(`{"name": "Test PWA", "start_url": "index.html"}`)) + return dn, nil +} + +func TestCollectCommands(t *testing.T) { + // Setup a test server for the website test + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/page2.html" { + fmt.Fprintln(w, `Hello`) + } else { + fmt.Fprintln(w, `Page 2`) + } + })) + defer server.Close() + + testCases := []struct { + name string + args []string + expectedStdout string + expectedStderr string + setup func(t *testing.T) + }{ + { + name: "collect github repos", + args: 
[]string{"collect", "github", "repos", "test"}, + expectedStdout: "https://github.com/test/repo1.git", + setup: func(t *testing.T) { + original := cmd.GithubClient + cmd.GithubClient = &mockGithubClient{} + t.Cleanup(func() { + cmd.GithubClient = original + }) + }, + }, + { + name: "collect github repo", + args: []string{"collect", "github", "repo", "https://github.com/test/repo1.git"}, + expectedStdout: "Repository saved to repo.datanode", + setup: func(t *testing.T) { + original := cmd.GitCloner + cmd.GitCloner = &mockGitCloner{} + t.Cleanup(func() { + cmd.GitCloner = original + }) + }, + }, + { + name: "collect pwa", + args: []string{"collect", "pwa", "--uri", "http://test.com"}, + expectedStdout: "PWA saved to pwa.datanode", + setup: func(t *testing.T) { + original := cmd.PWAClient + cmd.PWAClient = &mockPWAClient{} + t.Cleanup(func() { + cmd.PWAClient = original + }) + }, + }, + { + name: "collect website", + args: []string{"collect", "website", server.URL}, + expectedStdout: "Website saved to website.datanode", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if tc.setup != nil { + tc.setup(t) + } + + t.Cleanup(func() { + files, err := filepath.Glob("*.datanode") + if err != nil { + t.Fatalf("failed to glob for datanode files: %v", err) + } + for _, f := range files { + if err := os.Remove(f); err != nil { + t.Logf("failed to remove datanode file %s: %v", f, err) + } + } + }) + + rootCmd := cmd.NewRootCmd() + outBuf := new(bytes.Buffer) + errBuf := new(bytes.Buffer) + rootCmd.SetOut(outBuf) + rootCmd.SetErr(errBuf) + rootCmd.SetArgs(tc.args) + + err := rootCmd.ExecuteContext(context.Background()) + if err != nil { + t.Fatal(err) + } + + if tc.expectedStdout != "" && !strings.Contains(outBuf.String(), tc.expectedStdout) { + t.Errorf("expected stdout to contain %q, but got %q", tc.expectedStdout, outBuf.String()) + } + if tc.expectedStderr != "" && !strings.Contains(errBuf.String(), tc.expectedStderr) { + t.Errorf("expected 
stderr to contain %q, but got %q", tc.expectedStderr, errBuf.String()) + } + }) + } +} diff --git a/cmd/all.go b/cmd/all.go index 7bbd3e9..7f7c782 100644 --- a/cmd/all.go +++ b/cmd/all.go @@ -1,15 +1,12 @@ package cmd import ( - "context" "fmt" "log/slog" "os" "strings" - "github.com/Snider/Borg/pkg/github" "github.com/Snider/Borg/pkg/ui" - "github.com/Snider/Borg/pkg/vcs" "github.com/spf13/cobra" ) @@ -27,7 +24,7 @@ var allCmd = &cobra.Command{ fmt.Fprintln(os.Stderr, "Error: logger not properly initialised") return } - repos, err := github.GetPublicRepos(context.Background(), args[0]) + repos, err := GithubClient.GetPublicRepos(cmd.Context(), args[0]) if err != nil { log.Error("failed to get public repos", "err", err) return @@ -39,7 +36,7 @@ var allCmd = &cobra.Command{ log.Info("cloning repository", "url", repoURL) bar := ui.NewProgressBar(-1, "Cloning repository") - dn, err := vcs.CloneGitRepository(repoURL, bar) + dn, err := GitCloner.CloneGitRepository(repoURL, bar) bar.Finish() if err != nil { log.Error("failed to clone repository", "url", repoURL, "err", err) @@ -63,7 +60,6 @@ var allCmd = &cobra.Command{ }, } -// init registers the 'all' subcommand and its flags. func init() { RootCmd.AddCommand(allCmd) allCmd.PersistentFlags().String("output", ".", "Output directory for the DataNodes") diff --git a/cmd/collect.go b/cmd/collect.go index e655c67..8e3d817 100644 --- a/cmd/collect.go +++ b/cmd/collect.go @@ -11,7 +11,6 @@ var collectCmd = &cobra.Command{ Long: `Collect a resource from a git repository, a website, or other URI and store it in a DataNode.`, } -// init registers the 'collect' command under the root command. 
func init() { RootCmd.AddCommand(collectCmd) } diff --git a/cmd/collect_github.go b/cmd/collect_github.go index deb754f..58d6f8d 100644 --- a/cmd/collect_github.go +++ b/cmd/collect_github.go @@ -11,7 +11,6 @@ var collectGithubCmd = &cobra.Command{ Long: `Collect a resource from a GitHub repository, such as a repository or a release.`, } -// init registers the 'github' subcommand under the collect command. func init() { collectCmd.AddCommand(collectGithubCmd) } diff --git a/cmd/collect_github_release_subcommand.go b/cmd/collect_github_release_subcommand.go index 7d7a4fa..83565c0 100644 --- a/cmd/collect_github_release_subcommand.go +++ b/cmd/collect_github_release_subcommand.go @@ -126,7 +126,6 @@ var collectGithubReleaseCmd = &cobra.Command{ }, } -// init registers the 'release' subcommand and its flags under the GitHub command. func init() { collectGithubCmd.AddCommand(collectGithubReleaseCmd) collectGithubReleaseCmd.PersistentFlags().String("output", ".", "Output directory for the downloaded file") diff --git a/cmd/collect_github_repo.go b/cmd/collect_github_repo.go index 415b553..f279129 100644 --- a/cmd/collect_github_repo.go +++ b/cmd/collect_github_repo.go @@ -13,6 +13,11 @@ import ( "github.com/spf13/cobra" ) +var ( + // GitCloner is the git cloner used by the command. It can be replaced for testing. 
+ GitCloner = vcs.NewGitCloner() +) + // collectGithubRepoCmd represents the collect github repo command var collectGithubRepoCmd = &cobra.Command{ Use: "repo [repository-url]", @@ -35,9 +40,9 @@ var collectGithubRepoCmd = &cobra.Command{ progressWriter = ui.NewProgressWriter(bar) } - dn, err := vcs.CloneGitRepository(repoURL, progressWriter) + dn, err := GitCloner.CloneGitRepository(repoURL, progressWriter) if err != nil { - fmt.Printf("Error cloning repository: %v\n", err) + fmt.Fprintln(cmd.ErrOrStderr(), "Error cloning repository:", err) return } @@ -45,25 +50,25 @@ var collectGithubRepoCmd = &cobra.Command{ if format == "matrix" { matrix, err := matrix.FromDataNode(dn) if err != nil { - fmt.Printf("Error creating matrix: %v\n", err) + fmt.Fprintln(cmd.ErrOrStderr(), "Error creating matrix:", err) return } data, err = matrix.ToTar() if err != nil { - fmt.Printf("Error serializing matrix: %v\n", err) + fmt.Fprintln(cmd.ErrOrStderr(), "Error serializing matrix:", err) return } } else { data, err = dn.ToTar() if err != nil { - fmt.Printf("Error serializing DataNode: %v\n", err) + fmt.Fprintln(cmd.ErrOrStderr(), "Error serializing DataNode:", err) return } } compressedData, err := compress.Compress(data, compression) if err != nil { - fmt.Printf("Error compressing data: %v\n", err) + fmt.Fprintln(cmd.ErrOrStderr(), "Error compressing data:", err) return } @@ -76,15 +81,14 @@ var collectGithubRepoCmd = &cobra.Command{ err = os.WriteFile(outputFile, compressedData, 0644) if err != nil { - fmt.Printf("Error writing DataNode to file: %v\n", err) + fmt.Fprintln(cmd.ErrOrStderr(), "Error writing DataNode to file:", err) return } - fmt.Printf("Repository saved to %s\n", outputFile) + fmt.Fprintln(cmd.OutOrStdout(), "Repository saved to", outputFile) }, } -// init registers the 'repo' subcommand and its flags under the GitHub command. 
func init() { collectGithubCmd.AddCommand(collectGithubRepoCmd) collectGithubRepoCmd.PersistentFlags().String("output", "", "Output file for the DataNode") diff --git a/cmd/collect_github_repos.go b/cmd/collect_github_repos.go index d58ebf3..dfcd315 100644 --- a/cmd/collect_github_repos.go +++ b/cmd/collect_github_repos.go @@ -7,24 +7,27 @@ import ( "github.com/spf13/cobra" ) -// collectGithubReposCmd represents the command that lists public repositories for a user or organization. +var ( + // GithubClient is the github client used by the command. It can be replaced for testing. + GithubClient = github.NewGithubClient() +) + var collectGithubReposCmd = &cobra.Command{ Use: "repos [user-or-org]", Short: "Collects all public repositories for a user or organization", Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - repos, err := github.GetPublicRepos(cmd.Context(), args[0]) + repos, err := GithubClient.GetPublicRepos(cmd.Context(), args[0]) if err != nil { return err } for _, repo := range repos { - fmt.Println(repo) + fmt.Fprintln(cmd.OutOrStdout(), repo) } return nil }, } -// init registers the collectGithubReposCmd subcommand under the GitHub command. func init() { collectGithubCmd.AddCommand(collectGithubReposCmd) } diff --git a/cmd/collect_pwa.go b/cmd/collect_pwa.go index 878064a..06e06f5 100644 --- a/cmd/collect_pwa.go +++ b/cmd/collect_pwa.go @@ -12,6 +12,11 @@ import ( "github.com/spf13/cobra" ) +var ( + // PWAClient is the pwa client used by the command. It can be replaced for testing. 
+ PWAClient = pwa.NewPWAClient() +) + // collectPWACmd represents the collect pwa command var collectPWACmd = &cobra.Command{ Use: "pwa", @@ -27,22 +32,22 @@ Example: compression, _ := cmd.Flags().GetString("compression") if pwaURL == "" { - fmt.Println("Error: uri is required") + fmt.Fprintln(cmd.ErrOrStderr(), "Error: uri is required") return } bar := ui.NewProgressBar(-1, "Finding PWA manifest") defer bar.Finish() - manifestURL, err := pwa.FindManifest(pwaURL) + manifestURL, err := PWAClient.FindManifest(pwaURL) if err != nil { - fmt.Printf("Error finding manifest: %v\n", err) + fmt.Fprintln(cmd.ErrOrStderr(), "Error finding manifest:", err) return } bar.Describe("Downloading and packaging PWA") - dn, err := pwa.DownloadAndPackagePWA(pwaURL, manifestURL, bar) + dn, err := PWAClient.DownloadAndPackagePWA(pwaURL, manifestURL, bar) if err != nil { - fmt.Printf("Error downloading and packaging PWA: %v\n", err) + fmt.Fprintln(cmd.ErrOrStderr(), "Error downloading and packaging PWA:", err) return } @@ -50,25 +55,25 @@ Example: if format == "matrix" { matrix, err := matrix.FromDataNode(dn) if err != nil { - fmt.Printf("Error creating matrix: %v\n", err) + fmt.Fprintln(cmd.ErrOrStderr(), "Error creating matrix:", err) return } data, err = matrix.ToTar() if err != nil { - fmt.Printf("Error serializing matrix: %v\n", err) + fmt.Fprintln(cmd.ErrOrStderr(), "Error serializing matrix:", err) return } } else { data, err = dn.ToTar() if err != nil { - fmt.Printf("Error serializing DataNode: %v\n", err) + fmt.Fprintln(cmd.ErrOrStderr(), "Error serializing DataNode:", err) return } } compressedData, err := compress.Compress(data, compression) if err != nil { - fmt.Printf("Error compressing data: %v\n", err) + fmt.Fprintln(cmd.ErrOrStderr(), "Error compressing data:", err) return } @@ -81,15 +86,14 @@ Example: err = os.WriteFile(outputFile, compressedData, 0644) if err != nil { - fmt.Printf("Error writing PWA to file: %v\n", err) + fmt.Fprintln(cmd.ErrOrStderr(), "Error writing 
PWA to file:", err) return } - fmt.Printf("PWA saved to %s\n", outputFile) + fmt.Fprintln(cmd.OutOrStdout(), "PWA saved to", outputFile) }, } -// init registers the 'pwa' command and its flags under the collect command. func init() { collectCmd.AddCommand(collectPWACmd) collectPWACmd.Flags().String("uri", "", "The URI of the PWA to collect") diff --git a/cmd/collect_website.go b/cmd/collect_website.go index 80f354a..6b67a3b 100644 --- a/cmd/collect_website.go +++ b/cmd/collect_website.go @@ -4,11 +4,11 @@ import ( "fmt" "os" + "github.com/schollz/progressbar/v3" "github.com/Snider/Borg/pkg/compress" "github.com/Snider/Borg/pkg/matrix" "github.com/Snider/Borg/pkg/ui" "github.com/Snider/Borg/pkg/website" - "github.com/schollz/progressbar/v3" "github.com/spf13/cobra" ) @@ -36,7 +36,7 @@ var collectWebsiteCmd = &cobra.Command{ dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar) if err != nil { - fmt.Printf("Error downloading and packaging website: %v\n", err) + fmt.Fprintln(cmd.ErrOrStderr(), "Error downloading and packaging website:", err) return } @@ -44,25 +44,25 @@ var collectWebsiteCmd = &cobra.Command{ if format == "matrix" { matrix, err := matrix.FromDataNode(dn) if err != nil { - fmt.Printf("Error creating matrix: %v\n", err) + fmt.Fprintln(cmd.ErrOrStderr(), "Error creating matrix:", err) return } data, err = matrix.ToTar() if err != nil { - fmt.Printf("Error serializing matrix: %v\n", err) + fmt.Fprintln(cmd.ErrOrStderr(), "Error serializing matrix:", err) return } } else { data, err = dn.ToTar() if err != nil { - fmt.Printf("Error serializing DataNode: %v\n", err) + fmt.Fprintln(cmd.ErrOrStderr(), "Error serializing DataNode:", err) return } } compressedData, err := compress.Compress(data, compression) if err != nil { - fmt.Printf("Error compressing data: %v\n", err) + fmt.Fprintln(cmd.ErrOrStderr(), "Error compressing data:", err) return } @@ -75,15 +75,14 @@ var collectWebsiteCmd = &cobra.Command{ err = os.WriteFile(outputFile, 
compressedData, 0644) if err != nil { - fmt.Printf("Error writing website to file: %v\n", err) + fmt.Fprintln(cmd.ErrOrStderr(), "Error writing website to file:", err) return } - fmt.Printf("Website saved to %s\n", outputFile) + fmt.Fprintln(cmd.OutOrStdout(), "Website saved to", outputFile) }, } -// init registers the 'website' command and its flags under the collect command. func init() { collectCmd.AddCommand(collectWebsiteCmd) collectWebsiteCmd.PersistentFlags().String("output", "", "Output file for the DataNode") diff --git a/cmd/root.go b/cmd/root.go index d22dbc9..013d430 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -7,22 +7,26 @@ import ( "github.com/spf13/cobra" ) -// RootCmd represents the base command when called without any subcommands -var RootCmd = &cobra.Command{ - Use: "borg-data-collector", - Short: "A tool for collecting and managing data.", - Long: `Borg Data Collector is a command-line tool for cloning Git repositories, +func NewRootCmd() *cobra.Command { + rootCmd := &cobra.Command{ + Use: "borg-data-collector", + Short: "A tool for collecting and managing data.", + Long: `Borg Data Collector is a command-line tool for cloning Git repositories, packaging their contents into a single file, and managing the data within.`, + } + rootCmd.AddCommand(allCmd) + rootCmd.AddCommand(collectCmd) + rootCmd.AddCommand(serveCmd) + rootCmd.PersistentFlags().BoolP("verbose", "v", false, "Enable verbose logging") + return rootCmd } +// RootCmd represents the base command when called without any subcommands +var RootCmd = NewRootCmd() + // Execute adds all child commands to the root command and sets flags appropriately. // This is called by main.main(). It only needs to happen once to the rootCmd. func Execute(log *slog.Logger) error { RootCmd.SetContext(context.WithValue(context.Background(), "logger", log)) return RootCmd.Execute() } - -// init configures persistent flags for the root command. 
-func init() { - RootCmd.PersistentFlags().BoolP("verbose", "v", false, "Enable verbose logging") -} diff --git a/cmd/serve.go b/cmd/serve.go index c183ec5..87e225f 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -62,7 +62,6 @@ var serveCmd = &cobra.Command{ }, } -// init registers the 'serve' command and its flags under the root command. func init() { RootCmd.AddCommand(serveCmd) serveCmd.PersistentFlags().String("port", "8080", "Port to serve the PWA on") diff --git a/main.go b/main.go index b3fadd5..7c02e17 100644 --- a/main.go +++ b/main.go @@ -7,7 +7,6 @@ import ( "github.com/Snider/Borg/pkg/logger" ) -// main is the entry point of the application, initialises logger, and executes the root command with error handling. func main() { verbose, _ := cmd.RootCmd.PersistentFlags().GetBool("verbose") log := logger.New(verbose) diff --git a/pkg/datanode/datanode.go b/pkg/datanode/datanode.go index 9352f3e..fe2f43b 100644 --- a/pkg/datanode/datanode.go +++ b/pkg/datanode/datanode.go @@ -260,35 +260,19 @@ type dataFile struct { modTime time.Time } -// Stat returns a FileInfo describing the dataFile. func (d *dataFile) Stat() (fs.FileInfo, error) { return &dataFileInfo{file: d}, nil } - -// Read implements fs.File by returning EOF for write-only dataFile handles. func (d *dataFile) Read(p []byte) (int, error) { return 0, io.EOF } - -// Close is a no-op for in-memory dataFile values. -func (d *dataFile) Close() error { return nil } +func (d *dataFile) Close() error { return nil } // dataFileInfo implements fs.FileInfo for a dataFile. type dataFileInfo struct{ file *dataFile } -// Name returns the base name of the data file. -func (d *dataFileInfo) Name() string { return path.Base(d.file.name) } - -// Size returns the size of the data file in bytes. -func (d *dataFileInfo) Size() int64 { return int64(len(d.file.content)) } - -// Mode returns the file mode bits for a read-only regular file. 
-func (d *dataFileInfo) Mode() fs.FileMode { return 0444 } - -// ModTime returns the modification time of the data file. +func (d *dataFileInfo) Name() string { return path.Base(d.file.name) } +func (d *dataFileInfo) Size() int64 { return int64(len(d.file.content)) } +func (d *dataFileInfo) Mode() fs.FileMode { return 0444 } func (d *dataFileInfo) ModTime() time.Time { return d.file.modTime } - -// IsDir reports whether the FileInfo describes a directory (always false). -func (d *dataFileInfo) IsDir() bool { return false } - -// Sys returns underlying data source (always nil). -func (d *dataFileInfo) Sys() interface{} { return nil } +func (d *dataFileInfo) IsDir() bool { return false } +func (d *dataFileInfo) Sys() interface{} { return nil } // dataFileReader implements fs.File for a dataFile. type dataFileReader struct { @@ -296,18 +280,13 @@ type dataFileReader struct { reader *bytes.Reader } -// Stat returns a FileInfo describing the underlying data file. func (d *dataFileReader) Stat() (fs.FileInfo, error) { return d.file.Stat() } - -// Read reads from the underlying byte slice, initializing the reader on first use. func (d *dataFileReader) Read(p []byte) (int, error) { if d.reader == nil { d.reader = bytes.NewReader(d.file.content) } return d.reader.Read(p) } - -// Close is a no-op for in-memory readers. func (d *dataFileReader) Close() error { return nil } // dirInfo implements fs.FileInfo for an implicit directory. @@ -316,23 +295,12 @@ type dirInfo struct { modTime time.Time } -// Name returns the directory name. -func (d *dirInfo) Name() string { return d.name } - -// Size returns the size for a directory (always 0). -func (d *dirInfo) Size() int64 { return 0 } - -// Mode returns the file mode bits indicating a read-only directory. -func (d *dirInfo) Mode() fs.FileMode { return fs.ModeDir | 0555 } - -// ModTime returns the modification time of the directory. 
+func (d *dirInfo) Name() string { return d.name } +func (d *dirInfo) Size() int64 { return 0 } +func (d *dirInfo) Mode() fs.FileMode { return fs.ModeDir | 0555 } func (d *dirInfo) ModTime() time.Time { return d.modTime } - -// IsDir reports that this FileInfo describes a directory. -func (d *dirInfo) IsDir() bool { return true } - -// Sys returns underlying data source (always nil). -func (d *dirInfo) Sys() interface{} { return nil } +func (d *dirInfo) IsDir() bool { return true } +func (d *dirInfo) Sys() interface{} { return nil } // dirFile implements fs.File for a directory. type dirFile struct { diff --git a/pkg/github/github.go b/pkg/github/github.go index 4542033..022e255 100644 --- a/pkg/github/github.go +++ b/pkg/github/github.go @@ -11,20 +11,27 @@ import ( "golang.org/x/oauth2" ) -// Repo is a minimal representation of a GitHub repository used in this package. type Repo struct { CloneURL string `json:"clone_url"` } -// GetPublicRepos returns clone URLs for all public repositories owned by the given user or org. -// It uses the public GitHub API endpoint. -func GetPublicRepos(ctx context.Context, userOrOrg string) ([]string, error) { - return GetPublicReposWithAPIURL(ctx, "https://api.github.com", userOrOrg) +// GithubClient is an interface for interacting with the Github API. +type GithubClient interface { + GetPublicRepos(ctx context.Context, userOrOrg string) ([]string, error) } -// newAuthenticatedClient returns an HTTP client authenticated with a GitHub token if present. -// If the GITHUB_TOKEN environment variable is not set, it returns http.DefaultClient. -func newAuthenticatedClient(ctx context.Context) *http.Client { +// NewGithubClient creates a new GithubClient. 
+func NewGithubClient() GithubClient { + return &githubClient{} +} + +type githubClient struct{} + +func (g *githubClient) GetPublicRepos(ctx context.Context, userOrOrg string) ([]string, error) { + return g.getPublicReposWithAPIURL(ctx, "https://api.github.com", userOrOrg) +} + +func (g *githubClient) newAuthenticatedClient(ctx context.Context) *http.Client { token := os.Getenv("GITHUB_TOKEN") if token == "" { return http.DefaultClient @@ -35,10 +42,8 @@ func newAuthenticatedClient(ctx context.Context) *http.Client { return oauth2.NewClient(ctx, ts) } -// GetPublicReposWithAPIURL returns clone URLs for all public repositories for userOrOrg -// using the specified GitHub API base URL. It transparently follows pagination. -func GetPublicReposWithAPIURL(ctx context.Context, apiURL, userOrOrg string) ([]string, error) { - client := newAuthenticatedClient(ctx) +func (g *githubClient) getPublicReposWithAPIURL(ctx context.Context, apiURL, userOrOrg string) ([]string, error) { + client := g.newAuthenticatedClient(ctx) var allCloneURLs []string url := fmt.Sprintf("%s/users/%s/repos", apiURL, userOrOrg) @@ -91,7 +96,7 @@ func GetPublicReposWithAPIURL(ctx context.Context, apiURL, userOrOrg string) ([] if linkHeader == "" { break } - nextURL := findNextURL(linkHeader) + nextURL := g.findNextURL(linkHeader) if nextURL == "" { break } @@ -101,8 +106,7 @@ func GetPublicReposWithAPIURL(ctx context.Context, apiURL, userOrOrg string) ([] return allCloneURLs, nil } -// findNextURL parses the RFC 5988 Link header and returns the URL with rel="next", if any. -func findNextURL(linkHeader string) string { +func (g *githubClient) findNextURL(linkHeader string) string { links := strings.Split(linkHeader, ",") for _, link := range links { parts := strings.Split(link, ";") diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index e5021b3..0dfc2d2 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -5,8 +5,6 @@ import ( "os" ) -// New returns a configured slog.Logger. 
-// When verbose is true, the logger emits debug-level logs; otherwise info-level. func New(verbose bool) *slog.Logger { level := slog.LevelInfo if verbose { diff --git a/pkg/mocks/http.go b/pkg/mocks/http.go new file mode 100644 index 0000000..f47fe67 --- /dev/null +++ b/pkg/mocks/http.go @@ -0,0 +1,80 @@ +package mocks + +import ( + "bytes" + "io" + "net/http" + "sync" +) + +// MockRoundTripper is a mock implementation of http.RoundTripper. +type MockRoundTripper struct { + mu sync.RWMutex + responses map[string]*http.Response +} + +// SetResponses sets the mock responses in a thread-safe way. +func (m *MockRoundTripper) SetResponses(responses map[string]*http.Response) { + m.mu.Lock() + defer m.mu.Unlock() + m.responses = responses +} + +// RoundTrip implements the http.RoundTripper interface. +func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + url := req.URL.String() + m.mu.RLock() + resp, ok := m.responses[url] + m.mu.RUnlock() + + if ok { + // Read the original body + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + resp.Body.Close() // close original body + + // Re-hydrate the original body so it can be read again + resp.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + + // Create a deep copy of the response + newResp := &http.Response{ + Status: resp.Status, + StatusCode: resp.StatusCode, + Proto: resp.Proto, + ProtoMajor: resp.ProtoMajor, + ProtoMinor: resp.ProtoMinor, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(bodyBytes)), + ContentLength: resp.ContentLength, + TransferEncoding: resp.TransferEncoding, + Close: resp.Close, + Uncompressed: resp.Uncompressed, + Trailer: resp.Trailer.Clone(), + Request: resp.Request, + TLS: resp.TLS, + } + return newResp, nil + } + return &http.Response{ + StatusCode: http.StatusNotFound, + Body: io.NopCloser(bytes.NewBufferString("Not Found")), + Header: make(http.Header), + }, nil +} + +// NewMockClient creates a new http.Client with 
a MockRoundTripper. +func NewMockClient(responses map[string]*http.Response) *http.Client { + responsesCopy := make(map[string]*http.Response) + if responses != nil { + for k, v := range responses { + responsesCopy[k] = v + } + } + return &http.Client{ + Transport: &MockRoundTripper{ + responses: responsesCopy, + }, + } +} diff --git a/pkg/pwa/pwa.go b/pkg/pwa/pwa.go index b0f34a1..0d72345 100644 --- a/pkg/pwa/pwa.go +++ b/pkg/pwa/pwa.go @@ -29,9 +29,26 @@ type Icon struct { Type string `json:"type"` } +// PWAClient is an interface for finding and downloading PWAs. +type PWAClient interface { + FindManifest(pageURL string) (string, error) + DownloadAndPackagePWA(baseURL string, manifestURL string, bar *progressbar.ProgressBar) (*datanode.DataNode, error) +} + +// NewPWAClient creates a new PWAClient. +func NewPWAClient() PWAClient { + return &pwaClient{ + client: http.DefaultClient, + } +} + +type pwaClient struct { + client *http.Client +} + // FindManifest finds the manifest URL from a given HTML page. -func FindManifest(pageURL string) (string, error) { - resp, err := http.Get(pageURL) +func (p *pwaClient) FindManifest(pageURL string) (string, error) { + resp, err := p.client.Get(pageURL) if err != nil { return "", err } @@ -72,7 +89,7 @@ func FindManifest(pageURL string) (string, error) { return "", fmt.Errorf("manifest not found") } - resolvedURL, err := resolveURL(pageURL, manifestPath) + resolvedURL, err := p.resolveURL(pageURL, manifestPath) if err != nil { return "", fmt.Errorf("could not resolve manifest URL: %w", err) } @@ -81,16 +98,16 @@ func FindManifest(pageURL string) (string, error) { } // DownloadAndPackagePWA downloads all assets of a PWA and packages them into a DataNode. 
-func DownloadAndPackagePWA(baseURL string, manifestURL string, bar *progressbar.ProgressBar) (*datanode.DataNode, error) { +func (p *pwaClient) DownloadAndPackagePWA(baseURL string, manifestURL string, bar *progressbar.ProgressBar) (*datanode.DataNode, error) { if bar == nil { return nil, fmt.Errorf("progress bar cannot be nil") } - manifestAbsURL, err := resolveURL(baseURL, manifestURL) + manifestAbsURL, err := p.resolveURL(baseURL, manifestURL) if err != nil { return nil, fmt.Errorf("could not resolve manifest URL: %w", err) } - resp, err := http.Get(manifestAbsURL.String()) + resp, err := p.client.Get(manifestAbsURL.String()) if err != nil { return nil, fmt.Errorf("could not download manifest: %w", err) } @@ -110,30 +127,30 @@ func DownloadAndPackagePWA(baseURL string, manifestURL string, bar *progressbar. dn.AddData("manifest.json", manifestBody) if manifest.StartURL != "" { - startURLAbs, err := resolveURL(manifestAbsURL.String(), manifest.StartURL) + startURLAbs, err := p.resolveURL(manifestAbsURL.String(), manifest.StartURL) if err != nil { return nil, fmt.Errorf("could not resolve start_url: %w", err) } - err = downloadAndAddFile(dn, startURLAbs, manifest.StartURL, bar) + err = p.downloadAndAddFile(dn, startURLAbs, manifest.StartURL, bar) if err != nil { return nil, fmt.Errorf("failed to download start_url asset: %w", err) } } for _, icon := range manifest.Icons { - iconURLAbs, err := resolveURL(manifestAbsURL.String(), icon.Src) + iconURLAbs, err := p.resolveURL(manifestAbsURL.String(), icon.Src) if err != nil { fmt.Printf("Warning: could not resolve icon URL %s: %v\n", icon.Src, err) continue } - err = downloadAndAddFile(dn, iconURLAbs, icon.Src, bar) + err = p.downloadAndAddFile(dn, iconURLAbs, icon.Src, bar) if err != nil { fmt.Printf("Warning: failed to download icon %s: %v\n", icon.Src, err) } } baseURLAbs, _ := url.Parse(baseURL) - err = downloadAndAddFile(dn, baseURLAbs, "index.html", bar) + err = p.downloadAndAddFile(dn, baseURLAbs, "index.html", 
bar) if err != nil { return nil, fmt.Errorf("failed to download base HTML: %w", err) } @@ -141,8 +158,7 @@ func DownloadAndPackagePWA(baseURL string, manifestURL string, bar *progressbar. return dn, nil } -// resolveURL resolves ref against base and returns the absolute URL. -func resolveURL(base, ref string) (*url.URL, error) { +func (p *pwaClient) resolveURL(base, ref string) (*url.URL, error) { baseURL, err := url.Parse(base) if err != nil { return nil, err @@ -154,9 +170,8 @@ func resolveURL(base, ref string) (*url.URL, error) { return baseURL.ResolveReference(refURL), nil } -// downloadAndAddFile downloads the content at fileURL and adds it to the DataNode under internalPath. -func downloadAndAddFile(dn *datanode.DataNode, fileURL *url.URL, internalPath string, bar *progressbar.ProgressBar) error { - resp, err := http.Get(fileURL.String()) +func (p *pwaClient) downloadAndAddFile(dn *datanode.DataNode, fileURL *url.URL, internalPath string, bar *progressbar.ProgressBar) error { + resp, err := p.client.Get(fileURL.String()) if err != nil { return err } diff --git a/pkg/pwa/pwa_test.go b/pkg/pwa/pwa_test.go index 81728c8..4bf2315 100644 --- a/pkg/pwa/pwa_test.go +++ b/pkg/pwa/pwa_test.go @@ -1,13 +1,34 @@ package pwa import ( + "net" "net/http" "net/http/httptest" "testing" + "time" "github.com/schollz/progressbar/v3" ) +func newTestPWAClient(serverURL string) PWAClient { + return &pwaClient{ + client: &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + }, + }, + } +} + func TestFindManifest(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/html") @@ -26,8 +47,9 @@ func 
TestFindManifest(t *testing.T) { })) defer server.Close() + client := newTestPWAClient(server.URL) expectedURL := server.URL + "/manifest.json" - actualURL, err := FindManifest(server.URL) + actualURL, err := client.FindManifest(server.URL) if err != nil { t.Fatalf("FindManifest failed: %v", err) } @@ -80,8 +102,9 @@ func TestDownloadAndPackagePWA(t *testing.T) { })) defer server.Close() + client := newTestPWAClient(server.URL) bar := progressbar.New(1) - dn, err := DownloadAndPackagePWA(server.URL, server.URL+"/manifest.json", bar) + dn, err := client.DownloadAndPackagePWA(server.URL, server.URL+"/manifest.json", bar) if err != nil { t.Fatalf("DownloadAndPackagePWA failed: %v", err) } @@ -99,6 +122,7 @@ func TestDownloadAndPackagePWA(t *testing.T) { } func TestResolveURL(t *testing.T) { + client := NewPWAClient() tests := []struct { base string ref string @@ -113,7 +137,7 @@ func TestResolveURL(t *testing.T) { } for _, tt := range tests { - got, err := resolveURL(tt.base, tt.ref) + got, err := client.(*pwaClient).resolveURL(tt.base, tt.ref) if err != nil { t.Errorf("resolveURL(%q, %q) returned error: %v", tt.base, tt.ref, err) continue diff --git a/pkg/tarfs/tarfs.go b/pkg/tarfs/tarfs.go index a2a4c2b..6abbee4 100644 --- a/pkg/tarfs/tarfs.go +++ b/pkg/tarfs/tarfs.go @@ -67,23 +67,16 @@ type tarFile struct { modTime time.Time } -// Close implements http.File by doing nothing for a tar-backed file. -func (f *tarFile) Close() error { return nil } - -// Read reads from the tar-backed file content. +func (f *tarFile) Close() error { return nil } func (f *tarFile) Read(p []byte) (int, error) { return f.content.Read(p) } - -// Seek repositions the read offset within the tar-backed file content. func (f *tarFile) Seek(offset int64, whence int) (int64, error) { return f.content.Seek(offset, whence) } -// Readdir is unsupported for files and returns an error. 
func (f *tarFile) Readdir(count int) ([]os.FileInfo, error) { return nil, os.ErrInvalid } -// Stat returns a FileInfo describing the tar-backed file. func (f *tarFile) Stat() (os.FileInfo, error) { return &tarFileInfo{ name: path.Base(f.header.Name), @@ -99,20 +92,9 @@ type tarFileInfo struct { modTime time.Time } -// Name returns the base name of the tar-backed file. -func (i *tarFileInfo) Name() string { return i.name } - -// Size returns the size of the tar-backed file in bytes. -func (i *tarFileInfo) Size() int64 { return i.size } - -// Mode returns the file mode bits for a read-only regular file. -func (i *tarFileInfo) Mode() os.FileMode { return 0444 } - -// ModTime returns the file's modification time. +func (i *tarFileInfo) Name() string { return i.name } +func (i *tarFileInfo) Size() int64 { return i.size } +func (i *tarFileInfo) Mode() os.FileMode { return 0444 } func (i *tarFileInfo) ModTime() time.Time { return i.modTime } - -// IsDir reports whether the FileInfo describes a directory (always false). -func (i *tarFileInfo) IsDir() bool { return false } - -// Sys returns underlying data source (always nil). -func (i *tarFileInfo) Sys() interface{} { return nil } +func (i *tarFileInfo) IsDir() bool { return false } +func (i *tarFileInfo) Sys() interface{} { return nil } diff --git a/pkg/ui/non_interactive_prompter.go b/pkg/ui/non_interactive_prompter.go index 1c9492f..3eb874f 100644 --- a/pkg/ui/non_interactive_prompter.go +++ b/pkg/ui/non_interactive_prompter.go @@ -1,3 +1,4 @@ + package ui import ( @@ -10,24 +11,21 @@ import ( "github.com/mattn/go-isatty" ) -// NonInteractivePrompter periodically prints quotes when stdout is non-interactive. 
type NonInteractivePrompter struct { - stopChan chan struct{} - quoteFunc func() (string, error) - started bool - mu sync.Mutex - stopOnce sync.Once + stopChan chan struct{} + quoteFunc func() (string, error) + started bool + mu sync.Mutex + stopOnce sync.Once } -// NewNonInteractivePrompter constructs a NonInteractivePrompter using the provided quote function. func NewNonInteractivePrompter(quoteFunc func() (string, error)) *NonInteractivePrompter { return &NonInteractivePrompter{ - stopChan: make(chan struct{}), - quoteFunc: quoteFunc, + stopChan: make(chan struct{}), + quoteFunc: quoteFunc, } } -// Start begins periodic quote printing in a background goroutine when not interactive. func (p *NonInteractivePrompter) Start() { p.mu.Lock() if p.started { @@ -62,7 +60,6 @@ func (p *NonInteractivePrompter) Start() { }() } -// Stop signals the background goroutine to stop printing quotes. func (p *NonInteractivePrompter) Stop() { if p.IsInteractive() { return @@ -72,7 +69,6 @@ func (p *NonInteractivePrompter) Stop() { }) } -// IsInteractive reports whether stdout is attached to an interactive terminal. func (p *NonInteractivePrompter) IsInteractive() bool { return isatty.IsTerminal(os.Stdout.Fd()) || isatty.IsCygwinTerminal(os.Stdout.Fd()) } diff --git a/pkg/ui/progress_writer.go b/pkg/ui/progress_writer.go index 11798f5..b46b51b 100644 --- a/pkg/ui/progress_writer.go +++ b/pkg/ui/progress_writer.go @@ -1,19 +1,17 @@ + package ui import "github.com/schollz/progressbar/v3" -// ProgressWriter updates a progress bar’s description on writes. -type ProgressWriter struct { +type progressWriter struct { bar *progressbar.ProgressBar } -// NewProgressWriter creates a writer that sets the progress bar description to the last written line. 
-func NewProgressWriter(bar *progressbar.ProgressBar) *ProgressWriter { - return &ProgressWriter{bar: bar} +func NewProgressWriter(bar *progressbar.ProgressBar) *progressWriter { + return &progressWriter{bar: bar} } -// Write implements io.Writer by describing the progress with the provided bytes. -func (pw *ProgressWriter) Write(p []byte) (n int, err error) { +func (pw *progressWriter) Write(p []byte) (n int, err error) { if pw == nil || pw.bar == nil { return len(p), nil } diff --git a/pkg/ui/quote.go b/pkg/ui/quote.go index ef191f2..3a182cc 100644 --- a/pkg/ui/quote.go +++ b/pkg/ui/quote.go @@ -1,3 +1,4 @@ + package ui import ( @@ -17,12 +18,10 @@ var ( quotesErr error ) -// init seeds the random number generator for quote selection. func init() { rand.Seed(time.Now().UnixNano()) } -// Quotes contains categorized sets of quotes used by the UI. type Quotes struct { InitWorkAssimilate []string `json:"init_work_assimilate"` EncryptionServiceMessages []string `json:"encryption_service_messages"` @@ -44,7 +43,6 @@ type Quotes struct { } `json:"image_related"` } -// loadQuotes reads and unmarshals the embedded quotes JSON file. func loadQuotes() (*Quotes, error) { quotesFile, err := data.QuotesJSON.ReadFile("quotes.json") if err != nil { @@ -58,7 +56,6 @@ func loadQuotes() (*Quotes, error) { return &quotes, nil } -// getQuotes loads and caches the Quotes on first use, returning the cached instance thereafter. func getQuotes() (*Quotes, error) { quotesOnce.Do(func() { cachedQuotes, quotesErr = loadQuotes() @@ -66,7 +63,6 @@ func getQuotes() (*Quotes, error) { return cachedQuotes, quotesErr } -// GetRandomQuote returns a randomly selected quote from all categories. func GetRandomQuote() (string, error) { quotes, err := getQuotes() if err != nil { @@ -88,7 +84,6 @@ func GetRandomQuote() (string, error) { return allQuotes[rand.Intn(len(allQuotes))], nil } -// PrintQuote prints a randomly selected quote to stdout in green.
func PrintQuote() { quote, err := GetRandomQuote() if err != nil { @@ -99,7 +94,6 @@ func PrintQuote() { c.Println(quote) } -// GetVCSQuote returns a random quote from the VCSProcessing category. func GetVCSQuote() (string, error) { quotes, err := getQuotes() if err != nil { @@ -111,7 +105,6 @@ func GetVCSQuote() (string, error) { return quotes.VCSProcessing[rand.Intn(len(quotes.VCSProcessing))], nil } -// GetPWAQuote returns a random quote from the PWAProcessing category. func GetPWAQuote() (string, error) { quotes, err := getQuotes() if err != nil { @@ -123,7 +116,6 @@ func GetPWAQuote() (string, error) { return quotes.PWAProcessing[rand.Intn(len(quotes.PWAProcessing))], nil } -// GetWebsiteQuote returns a random quote from the CodeRelatedLong category. func GetWebsiteQuote() (string, error) { quotes, err := getQuotes() if err != nil { diff --git a/pkg/vcs/git.go b/pkg/vcs/git.go index 2d60cad..fcf10e1 100644 --- a/pkg/vcs/git.go +++ b/pkg/vcs/git.go @@ -10,8 +10,20 @@ import ( "github.com/go-git/go-git/v5" ) +// GitCloner is an interface for cloning Git repositories. +type GitCloner interface { + CloneGitRepository(repoURL string, progress io.Writer) (*datanode.DataNode, error) +} + +// NewGitCloner creates a new GitCloner. +func NewGitCloner() GitCloner { + return &gitCloner{} +} + +type gitCloner struct{} + // CloneGitRepository clones a Git repository from a URL and packages it into a DataNode. 
-func CloneGitRepository(repoURL string, progress io.Writer) (*datanode.DataNode, error) { +func (g *gitCloner) CloneGitRepository(repoURL string, progress io.Writer) (*datanode.DataNode, error) { tempPath, err := os.MkdirTemp("", "borg-clone-*") if err != nil { return nil, err diff --git a/pkg/vcs/git_test.go b/pkg/vcs/git_test.go index d36e310..fb662ac 100644 --- a/pkg/vcs/git_test.go +++ b/pkg/vcs/git_test.go @@ -69,7 +69,8 @@ func TestCloneGitRepository(t *testing.T) { } // Clone the repository using the function we're testing - dn, err := CloneGitRepository("file://"+bareRepoPath, os.Stdout) + cloner := NewGitCloner() + dn, err := cloner.CloneGitRepository("file://"+bareRepoPath, os.Stdout) if err != nil { t.Fatalf("CloneGitRepository failed: %v", err) } diff --git a/pkg/website/website.go b/pkg/website/website.go index ceff1e2..1780c0e 100644 --- a/pkg/website/website.go +++ b/pkg/website/website.go @@ -20,14 +20,21 @@ type Downloader struct { visited map[string]bool maxDepth int progressBar *progressbar.ProgressBar + client *http.Client } // NewDownloader creates a new Downloader. func NewDownloader(maxDepth int) *Downloader { + return NewDownloaderWithClient(maxDepth, http.DefaultClient) +} + +// NewDownloaderWithClient creates a new Downloader with a custom http.Client. +func NewDownloaderWithClient(maxDepth int, client *http.Client) *Downloader { return &Downloader{ dn: datanode.New(), visited: make(map[string]bool), maxDepth: maxDepth, + client: client, } } @@ -46,7 +53,6 @@ func DownloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.P return d.dn, nil } -// crawl visits pageURL, saves its content, and follows local links up to maxDepth. 
func (d *Downloader) crawl(pageURL string, depth int) { if depth > d.maxDepth || d.visited[pageURL] { return @@ -56,7 +62,7 @@ func (d *Downloader) crawl(pageURL string, depth int) { d.progressBar.Add(1) } - resp, err := http.Get(pageURL) + resp, err := d.client.Get(pageURL) if err != nil { fmt.Printf("Error getting %s: %v\n", pageURL, err) return @@ -104,7 +110,6 @@ func (d *Downloader) crawl(pageURL string, depth int) { f(doc) } -// downloadAsset fetches an asset by URL and stores it in the DataNode. func (d *Downloader) downloadAsset(assetURL string) { if d.visited[assetURL] { return @@ -114,7 +119,7 @@ func (d *Downloader) downloadAsset(assetURL string) { d.progressBar.Add(1) } - resp, err := http.Get(assetURL) + resp, err := d.client.Get(assetURL) if err != nil { fmt.Printf("Error getting asset %s: %v\n", assetURL, err) return @@ -131,7 +136,6 @@ func (d *Downloader) downloadAsset(assetURL string) { d.dn.AddData(relPath, body) } -// getRelativePath returns the path within the DataNode for the given page URL. func (d *Downloader) getRelativePath(pageURL string) string { u, err := url.Parse(pageURL) if err != nil { @@ -140,7 +144,6 @@ func (d *Downloader) getRelativePath(pageURL string) string { return strings.TrimPrefix(u.Path, "/") } -// resolveURL resolves ref against base and returns the absolute URL string. func (d *Downloader) resolveURL(base, ref string) (string, error) { baseURL, err := url.Parse(base) if err != nil { @@ -153,7 +156,6 @@ func (d *Downloader) resolveURL(base, ref string) (string, error) { return baseURL.ResolveReference(refURL).String(), nil } -// isLocal reports whether pageURL shares the same hostname as the base URL. func (d *Downloader) isLocal(pageURL string) bool { u, err := url.Parse(pageURL) if err != nil { @@ -162,7 +164,6 @@ func (d *Downloader) isLocal(pageURL string) bool { return u.Hostname() == d.baseURL.Hostname() } -// isAsset reports whether the URL likely points to a static asset by file extension. 
func isAsset(pageURL string) bool { ext := []string{".css", ".js", ".png", ".jpg", ".jpeg", ".gif", ".svg", ".ico"} for _, e := range ext {