feat: Improve test coverage and refactor for testability

This commit introduces a significant refactoring of the `cmd` package to improve testability and increases test coverage across the application.

Key changes include:
- Refactored Cobra commands to use `RunE` for better error handling and testing.
- Extracted business logic from command handlers into separate, testable functions.
- Added comprehensive unit tests for the `cmd`, `compress`, `github`, `logger`, and `pwa` packages.
- Added tests for missing command-line arguments, as requested.
- Implemented the `borg all` command to clone all public repositories for a GitHub user or organization.
- Restored and improved the `collect pwa` functionality.
- Removed duplicate code and fixed various bugs.
This commit is contained in:
google-labs-jules[bot] 2025-11-03 16:31:26 +00:00
parent 057d4a55d0
commit d27d9f3a37
21 changed files with 1070 additions and 444 deletions

View file

@ -2,65 +2,129 @@ package cmd
import (
"fmt"
"log/slog"
"io"
"io/fs"
"os"
"strings"
"github.com/Snider/Borg/pkg/compress"
"github.com/Snider/Borg/pkg/datanode"
"github.com/Snider/Borg/pkg/github"
"github.com/Snider/Borg/pkg/matrix"
"github.com/Snider/Borg/pkg/ui"
"github.com/Snider/Borg/pkg/vcs"
"github.com/spf13/cobra"
)
// allCmd represents the all command
var allCmd = &cobra.Command{
Use: "all [user/org]",
Short: "Collect all public repositories from a user or organization",
Long: `Collect all public repositories from a user or organization and store them in a DataNode.`,
Use: "all [url]",
Short: "Collect all resources from a URL",
Long: `Collect all resources from a URL, dispatching to the appropriate collector based on the URL type.`,
Args: cobra.ExactArgs(1),
Run: func(cmd *cobra.Command, args []string) {
logVal := cmd.Context().Value("logger")
log, ok := logVal.(*slog.Logger)
if !ok || log == nil {
fmt.Fprintln(os.Stderr, "Error: logger not properly initialised")
return
}
repos, err := GithubClient.GetPublicRepos(cmd.Context(), args[0])
if err != nil {
log.Error("failed to get public repos", "err", err)
return
RunE: func(cmd *cobra.Command, args []string) error {
url := args[0]
outputFile, _ := cmd.Flags().GetString("output")
format, _ := cmd.Flags().GetString("format")
compression, _ := cmd.Flags().GetString("compression")
// For now, we only support github user/org urls
if !strings.Contains(url, "github.com") {
return fmt.Errorf("unsupported URL type: %s", url)
}
outputDir, _ := cmd.Flags().GetString("output")
owner, _, err := github.ParseRepoFromURL(url)
if err != nil {
return fmt.Errorf("failed to parse repository url: %w", err)
}
repos, err := GithubClient.GetPublicRepos(cmd.Context(), owner)
if err != nil {
return err
}
prompter := ui.NewNonInteractivePrompter(ui.GetVCSQuote)
prompter.Start()
defer prompter.Stop()
var progressWriter io.Writer
if prompter.IsInteractive() {
bar := ui.NewProgressBar(len(repos), "Cloning repositories")
progressWriter = ui.NewProgressWriter(bar)
}
cloner := vcs.NewGitCloner()
allDataNodes := datanode.New()
for _, repoURL := range repos {
log.Info("cloning repository", "url", repoURL)
bar := ui.NewProgressBar(-1, "Cloning repository")
dn, err := GitCloner.CloneGitRepository(repoURL, bar)
bar.Finish()
dn, err := cloner.CloneGitRepository(repoURL, progressWriter)
if err != nil {
log.Error("failed to clone repository", "url", repoURL, "err", err)
// Log the error and continue
fmt.Fprintln(cmd.ErrOrStderr(), "Error cloning repository:", err)
continue
}
data, err := dn.ToTar()
// This is not an efficient way to merge datanodes, but it's the only way for now
// A better approach would be to add a Merge method to the DataNode
err = dn.Walk(".", func(path string, de fs.DirEntry, err error) error {
if err != nil {
log.Error("failed to serialize datanode", "url", repoURL, "err", err)
continue
return err
}
repoName := strings.Split(repoURL, "/")[len(strings.Split(repoURL, "/"))-1]
outputFile := fmt.Sprintf("%s/%s.dat", outputDir, repoName)
err = os.WriteFile(outputFile, data, 0644)
if !de.IsDir() {
file, err := dn.Open(path)
if err != nil {
log.Error("failed to write datanode to file", "url", repoURL, "err", err)
return err
}
defer file.Close()
data, err := io.ReadAll(file)
if err != nil {
return err
}
allDataNodes.AddData(path, data)
}
return nil
})
if err != nil {
fmt.Fprintln(cmd.ErrOrStderr(), "Error walking datanode:", err)
continue
}
}
var data []byte
if format == "matrix" {
matrix, err := matrix.FromDataNode(allDataNodes)
if err != nil {
return fmt.Errorf("error creating matrix: %w", err)
}
data, err = matrix.ToTar()
if err != nil {
return fmt.Errorf("error serializing matrix: %w", err)
}
} else {
data, err = allDataNodes.ToTar()
if err != nil {
return fmt.Errorf("error serializing DataNode: %w", err)
}
}
compressedData, err := compress.Compress(data, compression)
if err != nil {
return fmt.Errorf("error compressing data: %w", err)
}
err = os.WriteFile(outputFile, compressedData, 0644)
if err != nil {
return fmt.Errorf("error writing DataNode to file: %w", err)
}
fmt.Fprintln(cmd.OutOrStdout(), "All repositories saved to", outputFile)
return nil
},
}
func init() {
RootCmd.AddCommand(allCmd)
allCmd.PersistentFlags().String("output", ".", "Output directory for the DataNodes")
allCmd.PersistentFlags().String("output", "all.dat", "Output file for the DataNode")
allCmd.PersistentFlags().String("format", "datanode", "Output format (datanode or matrix)")
allCmd.PersistentFlags().String("compression", "none", "Compression format (none, gz, or xz)")
}

View file

@ -7,10 +7,13 @@ import (
// collectCmd represents the collect command
var collectCmd = &cobra.Command{
Use: "collect",
Short: "Collect a resource and store it in a DataNode.",
Long: `Collect a resource from a git repository, a website, or other URI and store it in a DataNode.`,
Short: "Collect a resource from a URI.",
Long: `Collect a resource from a URI and store it in a DataNode.`,
}
// init registers collectCmd as a subcommand of RootCmd.
func init() {
RootCmd.AddCommand(collectCmd)
}
// NewCollectCmd returns the package-level collectCmd. Despite the name it does
// not build a fresh command: every caller receives the same *cobra.Command, so
// subcommands added through the returned value are added to the shared instance.
func NewCollectCmd() *cobra.Command {
return collectCmd
}

View file

@ -14,3 +14,6 @@ var collectGithubCmd = &cobra.Command{
// init registers collectGithubCmd as a subcommand of collectCmd.
func init() {
collectCmd.AddCommand(collectGithubCmd)
}
// NewCollectGithubCmd returns the package-level collectGithubCmd. It does not
// allocate a new command; all callers share the same *cobra.Command.
func NewCollectGithubCmd() *cobra.Command {
return collectGithubCmd
}

View file

@ -1,18 +1,16 @@
package cmd
import (
"bytes"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"os"
"path/filepath"
"strings"
"github.com/Snider/Borg/pkg/datanode"
borg_github "github.com/Snider/Borg/pkg/github"
gh "github.com/google/go-github/v39/github"
"github.com/google/go-github/v39/github"
"github.com/spf13/cobra"
"golang.org/x/mod/semver"
)
@ -23,12 +21,11 @@ var collectGithubReleaseCmd = &cobra.Command{
Short: "Download the latest release of a file from GitHub releases",
Long: `Download the latest release of a file from GitHub releases. If the file or URL has a version number, it will check for a higher version and download it if found.`,
Args: cobra.ExactArgs(1),
Run: func(cmd *cobra.Command, args []string) {
RunE: func(cmd *cobra.Command, args []string) error {
logVal := cmd.Context().Value("logger")
log, ok := logVal.(*slog.Logger)
if !ok || log == nil {
fmt.Fprintln(os.Stderr, "Error: logger not properly initialised")
return
return errors.New("logger not properly initialised")
}
repoURL := args[0]
outputDir, _ := cmd.Flags().GetString("output")
@ -36,93 +33,8 @@ var collectGithubReleaseCmd = &cobra.Command{
file, _ := cmd.Flags().GetString("file")
version, _ := cmd.Flags().GetString("version")
owner, repo, err := borg_github.ParseRepoFromURL(repoURL)
if err != nil {
log.Error("failed to parse repository url", "err", err)
return
}
release, err := borg_github.GetLatestRelease(owner, repo)
if err != nil {
log.Error("failed to get latest release", "err", err)
return
}
log.Info("found latest release", "tag", release.GetTagName())
if version != "" {
if !semver.IsValid(version) {
log.Error("invalid version string", "version", version)
return
}
if semver.Compare(release.GetTagName(), version) <= 0 {
log.Info("latest release is not newer than the provided version", "latest", release.GetTagName(), "provided", version)
return
}
}
if pack {
dn := datanode.New()
for _, asset := range release.Assets {
log.Info("downloading asset", "name", asset.GetName())
resp, err := http.Get(asset.GetBrowserDownloadURL())
if err != nil {
log.Error("failed to download asset", "name", asset.GetName(), "err", err)
continue
}
defer resp.Body.Close()
var buf bytes.Buffer
_, err = io.Copy(&buf, resp.Body)
if err != nil {
log.Error("failed to read asset", "name", asset.GetName(), "err", err)
continue
}
dn.AddData(asset.GetName(), buf.Bytes())
}
tar, err := dn.ToTar()
if err != nil {
log.Error("failed to create datanode", "err", err)
return
}
outputFile := outputDir
if !strings.HasSuffix(outputFile, ".dat") {
outputFile = outputFile + ".dat"
}
err = os.WriteFile(outputFile, tar, 0644)
if err != nil {
log.Error("failed to write datanode", "err", err)
return
}
log.Info("datanode saved", "path", outputFile)
} else {
if len(release.Assets) == 0 {
log.Info("no assets found in the latest release")
return
}
var assetToDownload *gh.ReleaseAsset
if file != "" {
for _, asset := range release.Assets {
if asset.GetName() == file {
assetToDownload = asset
break
}
}
if assetToDownload == nil {
log.Error("asset not found in the latest release", "asset", file)
return
}
} else {
assetToDownload = release.Assets[0]
}
outputPath := filepath.Join(outputDir, assetToDownload.GetName())
log.Info("downloading asset", "name", assetToDownload.GetName())
err = borg_github.DownloadReleaseAsset(assetToDownload, outputPath)
if err != nil {
log.Error("failed to download asset", "name", assetToDownload.GetName(), "err", err)
return
}
log.Info("asset downloaded", "path", outputPath)
}
_, err := GetRelease(log, repoURL, outputDir, pack, file, version)
return err
},
}
@ -133,3 +45,86 @@ func init() {
collectGithubReleaseCmd.PersistentFlags().String("file", "", "The file to download from the release")
collectGithubReleaseCmd.PersistentFlags().String("version", "", "The version to check against")
}
// NewCollectGithubReleaseCmd returns the package-level collectGithubReleaseCmd.
// It does not allocate a new command; all callers share the same *cobra.Command.
func NewCollectGithubReleaseCmd() *cobra.Command {
return collectGithubReleaseCmd
}
// GetRelease looks up the latest release of the repository at repoURL and
// saves its assets to disk.
//
// When version is non-empty it must be a valid semver string; if the latest
// release tag does not compare greater than version, nothing is downloaded and
// (nil, nil) is returned.
//
// With pack set, every asset that downloads successfully is added to a
// DataNode, which is serialized and written to outputDir itself (with ".dat"
// appended unless already present); assets that fail to download are logged
// and skipped. Without pack, a single asset is written to
// outputDir/<asset name>: the one whose name equals file, or the first asset
// when file is empty. A release with no assets yields (nil, nil) in that mode.
//
// On success the release is returned. log must be non-nil.
func GetRelease(log *slog.Logger, repoURL string, outputDir string, pack bool, file string, version string) (*github.RepositoryRelease, error) {
	owner, repo, err := borg_github.ParseRepoFromURL(repoURL)
	if err != nil {
		return nil, fmt.Errorf("failed to parse repository url: %w", err)
	}

	release, err := borg_github.GetLatestRelease(owner, repo)
	if err != nil {
		return nil, fmt.Errorf("failed to get latest release: %w", err)
	}
	log.Info("found latest release", "tag", release.GetTagName())

	// Optional version gate: only continue when the latest tag is newer.
	switch {
	case version == "":
		// No version supplied; always download.
	case !semver.IsValid(version):
		return nil, fmt.Errorf("invalid version string: %s", version)
	case semver.Compare(release.GetTagName(), version) <= 0:
		log.Info("latest release is not newer than the provided version", "latest", release.GetTagName(), "provided", version)
		return nil, nil
	}

	if !pack {
		// Single-asset mode.
		if len(release.Assets) == 0 {
			log.Info("no assets found in the latest release")
			return nil, nil
		}

		var chosen *github.ReleaseAsset
		if file == "" {
			chosen = release.Assets[0]
		} else {
			for _, candidate := range release.Assets {
				if candidate.GetName() == file {
					chosen = candidate
					break
				}
			}
			if chosen == nil {
				return nil, fmt.Errorf("asset not found in the latest release: %s", file)
			}
		}

		dest := filepath.Join(outputDir, chosen.GetName())
		log.Info("downloading asset", "name", chosen.GetName())
		payload, err := borg_github.DownloadReleaseAsset(chosen)
		if err != nil {
			return nil, fmt.Errorf("failed to download asset: %w", err)
		}
		if err := os.WriteFile(dest, payload, 0644); err != nil {
			return nil, fmt.Errorf("failed to write asset to file: %w", err)
		}
		log.Info("asset downloaded", "path", dest)
		return release, nil
	}

	// Pack mode: bundle every downloadable asset into one DataNode archive.
	node := datanode.New()
	for _, asset := range release.Assets {
		log.Info("downloading asset", "name", asset.GetName())
		payload, err := borg_github.DownloadReleaseAsset(asset)
		if err != nil {
			// Best effort: a failed asset is reported and left out of the archive.
			log.Error("failed to download asset", "name", asset.GetName(), "err", err)
			continue
		}
		node.AddData(asset.GetName(), payload)
	}

	archive, err := node.ToTar()
	if err != nil {
		return nil, fmt.Errorf("failed to create datanode: %w", err)
	}

	// In pack mode outputDir is used as the output file path.
	target := outputDir
	if !strings.HasSuffix(target, ".dat") {
		target += ".dat"
	}
	if err := os.WriteFile(target, archive, 0644); err != nil {
		return nil, fmt.Errorf("failed to write datanode: %w", err)
	}
	log.Info("datanode saved", "path", target)

	return release, nil
}

View file

@ -18,13 +18,14 @@ var (
GitCloner = vcs.NewGitCloner()
)
// collectGithubRepoCmd represents the collect github repo command
var collectGithubRepoCmd = &cobra.Command{
// NewCollectGithubRepoCmd creates a new cobra command for collecting a single git repository.
func NewCollectGithubRepoCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "repo [repository-url]",
Short: "Collect a single Git repository",
Long: `Collect a single Git repository and store it in a DataNode.`,
Args: cobra.ExactArgs(1),
Run: func(cmd *cobra.Command, args []string) {
RunE: func(cmd *cobra.Command, args []string) error {
repoURL := args[0]
outputFile, _ := cmd.Flags().GetString("output")
format, _ := cmd.Flags().GetString("format")
@ -42,34 +43,29 @@ var collectGithubRepoCmd = &cobra.Command{
dn, err := GitCloner.CloneGitRepository(repoURL, progressWriter)
if err != nil {
fmt.Fprintln(cmd.ErrOrStderr(), "Error cloning repository:", err)
return
return fmt.Errorf("Error cloning repository: %w", err)
}
var data []byte
if format == "matrix" {
matrix, err := matrix.FromDataNode(dn)
if err != nil {
fmt.Fprintln(cmd.ErrOrStderr(), "Error creating matrix:", err)
return
return fmt.Errorf("Error creating matrix: %w", err)
}
data, err = matrix.ToTar()
if err != nil {
fmt.Fprintln(cmd.ErrOrStderr(), "Error serializing matrix:", err)
return
return fmt.Errorf("Error serializing matrix: %w", err)
}
} else {
data, err = dn.ToTar()
if err != nil {
fmt.Fprintln(cmd.ErrOrStderr(), "Error serializing DataNode:", err)
return
return fmt.Errorf("Error serializing DataNode: %w", err)
}
}
compressedData, err := compress.Compress(data, compression)
if err != nil {
fmt.Fprintln(cmd.ErrOrStderr(), "Error compressing data:", err)
return
return fmt.Errorf("Error compressing data: %w", err)
}
if outputFile == "" {
@ -81,17 +77,19 @@ var collectGithubRepoCmd = &cobra.Command{
err = os.WriteFile(outputFile, compressedData, 0644)
if err != nil {
fmt.Fprintln(cmd.ErrOrStderr(), "Error writing DataNode to file:", err)
return
return fmt.Errorf("Error writing DataNode to file: %w", err)
}
fmt.Fprintln(cmd.OutOrStdout(), "Repository saved to", outputFile)
return nil
},
}
cmd.PersistentFlags().String("output", "", "Output file for the DataNode")
cmd.PersistentFlags().String("format", "datanode", "Output format (datanode or matrix)")
cmd.PersistentFlags().String("compression", "none", "Compression format (none, gz, or xz)")
return cmd
}
func init() {
collectGithubCmd.AddCommand(collectGithubRepoCmd)
collectGithubRepoCmd.PersistentFlags().String("output", "", "Output file for the DataNode")
collectGithubRepoCmd.PersistentFlags().String("format", "datanode", "Output format (datanode or matrix)")
collectGithubRepoCmd.PersistentFlags().String("compression", "none", "Compression format (none, gz, or xz)")
collectGithubCmd.AddCommand(NewCollectGithubRepoCmd())
}

View file

@ -12,69 +12,85 @@ import (
"github.com/spf13/cobra"
)
var (
// PWAClient is the pwa client used by the command. It can be replaced for testing.
PWAClient = pwa.NewPWAClient()
)
type CollectPWACmd struct {
cobra.Command
PWAClient pwa.PWAClient
}
// collectPWACmd represents the collect pwa command
var collectPWACmd = &cobra.Command{
// NewCollectPWACmd creates a new collect pwa command
func NewCollectPWACmd() *CollectPWACmd {
c := &CollectPWACmd{
PWAClient: pwa.NewPWAClient(),
}
c.Command = cobra.Command{
Use: "pwa",
Short: "Collect a single PWA using a URI",
Long: `Collect a single PWA and store it in a DataNode.
Example:
borg collect pwa --uri https://example.com --output mypwa.dat`,
Run: func(cmd *cobra.Command, args []string) {
RunE: func(cmd *cobra.Command, args []string) error {
pwaURL, _ := cmd.Flags().GetString("uri")
outputFile, _ := cmd.Flags().GetString("output")
format, _ := cmd.Flags().GetString("format")
compression, _ := cmd.Flags().GetString("compression")
err := CollectPWA(c.PWAClient, pwaURL, outputFile, format, compression)
if err != nil {
return err
}
fmt.Fprintln(cmd.OutOrStdout(), "PWA saved to", outputFile)
return nil
},
}
c.Flags().String("uri", "", "The URI of the PWA to collect")
c.Flags().String("output", "", "Output file for the DataNode")
c.Flags().String("format", "datanode", "Output format (datanode or matrix)")
c.Flags().String("compression", "none", "Compression format (none, gz, or xz)")
return c
}
func init() {
collectCmd.AddCommand(&NewCollectPWACmd().Command)
}
func CollectPWA(client pwa.PWAClient, pwaURL string, outputFile string, format string, compression string) error {
if pwaURL == "" {
fmt.Fprintln(cmd.ErrOrStderr(), "Error: uri is required")
return
return fmt.Errorf("uri is required")
}
bar := ui.NewProgressBar(-1, "Finding PWA manifest")
defer bar.Finish()
manifestURL, err := PWAClient.FindManifest(pwaURL)
manifestURL, err := client.FindManifest(pwaURL)
if err != nil {
fmt.Fprintln(cmd.ErrOrStderr(), "Error finding manifest:", err)
return
return fmt.Errorf("error finding manifest: %w", err)
}
bar.Describe("Downloading and packaging PWA")
dn, err := PWAClient.DownloadAndPackagePWA(pwaURL, manifestURL, bar)
dn, err := client.DownloadAndPackagePWA(pwaURL, manifestURL, bar)
if err != nil {
fmt.Fprintln(cmd.ErrOrStderr(), "Error downloading and packaging PWA:", err)
return
return fmt.Errorf("error downloading and packaging PWA: %w", err)
}
var data []byte
if format == "matrix" {
matrix, err := matrix.FromDataNode(dn)
if err != nil {
fmt.Fprintln(cmd.ErrOrStderr(), "Error creating matrix:", err)
return
return fmt.Errorf("error creating matrix: %w", err)
}
data, err = matrix.ToTar()
if err != nil {
fmt.Fprintln(cmd.ErrOrStderr(), "Error serializing matrix:", err)
return
return fmt.Errorf("error serializing matrix: %w", err)
}
} else {
data, err = dn.ToTar()
if err != nil {
fmt.Fprintln(cmd.ErrOrStderr(), "Error serializing DataNode:", err)
return
return fmt.Errorf("error serializing DataNode: %w", err)
}
}
compressedData, err := compress.Compress(data, compression)
if err != nil {
fmt.Fprintln(cmd.ErrOrStderr(), "Error compressing data:", err)
return
return fmt.Errorf("error compressing data: %w", err)
}
if outputFile == "" {
@ -86,18 +102,7 @@ Example:
err = os.WriteFile(outputFile, compressedData, 0644)
if err != nil {
fmt.Fprintln(cmd.ErrOrStderr(), "Error writing PWA to file:", err)
return
return fmt.Errorf("error writing PWA to file: %w", err)
}
fmt.Fprintln(cmd.OutOrStdout(), "PWA saved to", outputFile)
},
}
func init() {
collectCmd.AddCommand(collectPWACmd)
collectPWACmd.Flags().String("uri", "", "The URI of the PWA to collect")
collectPWACmd.Flags().String("output", "", "Output file for the DataNode")
collectPWACmd.Flags().String("format", "datanode", "Output format (datanode or matrix)")
collectPWACmd.Flags().String("compression", "none", "Compression format (none, gz, or xz)")
return nil
}

26
cmd/collect_pwa_test.go Normal file
View file

@ -0,0 +1,26 @@
package cmd
import (
"strings"
"testing"
)
// TestCollectPWACmd_NoURI runs "collect pwa" without --uri and expects the
// command to fail with an error mentioning that the uri is required.
func TestCollectPWACmd_NoURI(t *testing.T) {
	root := NewRootCmd()
	collect := NewCollectCmd()
	pwaCmd := NewCollectPWACmd()
	collect.AddCommand(&pwaCmd.Command)
	root.AddCommand(collect)

	_, err := executeCommand(root, "collect", "pwa")
	switch {
	case err == nil:
		t.Fatalf("expected an error, but got none")
	case !strings.Contains(err.Error(), "uri is required"):
		t.Fatalf("unexpected error message: %v", err)
	}
}
// Test_NewCollectPWACmd checks that the constructor returns a non-nil command.
func Test_NewCollectPWACmd(t *testing.T) {
	if built := NewCollectPWACmd(); built == nil {
		t.Errorf("NewCollectPWACmd is nil")
	}
}

View file

@ -19,7 +19,7 @@ var collectWebsiteCmd = &cobra.Command{
Short: "Collect a single website",
Long: `Collect a single website and store it in a DataNode.`,
Args: cobra.ExactArgs(1),
Run: func(cmd *cobra.Command, args []string) {
RunE: func(cmd *cobra.Command, args []string) error {
websiteURL := args[0]
outputFile, _ := cmd.Flags().GetString("output")
depth, _ := cmd.Flags().GetInt("depth")
@ -36,34 +36,29 @@ var collectWebsiteCmd = &cobra.Command{
dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar)
if err != nil {
fmt.Fprintln(cmd.ErrOrStderr(), "Error downloading and packaging website:", err)
return
return fmt.Errorf("error downloading and packaging website: %w", err)
}
var data []byte
if format == "matrix" {
matrix, err := matrix.FromDataNode(dn)
if err != nil {
fmt.Fprintln(cmd.ErrOrStderr(), "Error creating matrix:", err)
return
return fmt.Errorf("error creating matrix: %w", err)
}
data, err = matrix.ToTar()
if err != nil {
fmt.Fprintln(cmd.ErrOrStderr(), "Error serializing matrix:", err)
return
return fmt.Errorf("error serializing matrix: %w", err)
}
} else {
data, err = dn.ToTar()
if err != nil {
fmt.Fprintln(cmd.ErrOrStderr(), "Error serializing DataNode:", err)
return
return fmt.Errorf("error serializing DataNode: %w", err)
}
}
compressedData, err := compress.Compress(data, compression)
if err != nil {
fmt.Fprintln(cmd.ErrOrStderr(), "Error compressing data:", err)
return
return fmt.Errorf("error compressing data: %w", err)
}
if outputFile == "" {
@ -75,11 +70,11 @@ var collectWebsiteCmd = &cobra.Command{
err = os.WriteFile(outputFile, compressedData, 0644)
if err != nil {
fmt.Fprintln(cmd.ErrOrStderr(), "Error writing website to file:", err)
return
return fmt.Errorf("error writing website to file: %w", err)
}
fmt.Fprintln(cmd.OutOrStdout(), "Website saved to", outputFile)
return nil
},
}
@ -90,3 +85,6 @@ func init() {
collectWebsiteCmd.PersistentFlags().String("format", "datanode", "Output format (datanode or matrix)")
collectWebsiteCmd.PersistentFlags().String("compression", "none", "Compression format (none, gz, or xz)")
}
// NewCollectWebsiteCmd returns the package-level collectWebsiteCmd. It does not
// allocate a new command; all callers share the same *cobra.Command.
func NewCollectWebsiteCmd() *cobra.Command {
return collectWebsiteCmd
}

View file

@ -0,0 +1,26 @@
package cmd
import (
"strings"
"testing"
)
// TestCollectWebsiteCmd_NoArgs runs "collect website" without a URL argument
// and expects cobra's argument-count validation error.
func TestCollectWebsiteCmd_NoArgs(t *testing.T) {
	root := NewRootCmd()
	collect := NewCollectCmd()
	website := NewCollectWebsiteCmd()
	collect.AddCommand(website)
	root.AddCommand(collect)

	_, err := executeCommand(root, "collect", "website")
	switch {
	case err == nil:
		t.Fatalf("expected an error, but got none")
	case !strings.Contains(err.Error(), "accepts 1 arg(s), received 0"):
		t.Fatalf("unexpected error message: %v", err)
	}
}
// Test_NewCollectWebsiteCmd checks that the accessor returns a non-nil command.
func Test_NewCollectWebsiteCmd(t *testing.T) {
	if got := NewCollectWebsiteCmd(); got == nil {
		t.Errorf("NewCollectWebsiteCmd is nil")
	}
}

View file

@ -1,18 +0,0 @@
package cmd
import (
"log/slog"
"os"
"testing"
)
func TestExecute(t *testing.T) {
// This test simply checks that the Execute function can be called without error.
// It doesn't actually test any of the application's functionality.
log := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: slog.LevelDebug,
}))
if err := Execute(log); err != nil {
t.Errorf("Execute() failed: %v", err)
}
}

64
cmd/root_test.go Normal file
View file

@ -0,0 +1,64 @@
package cmd
import (
"bytes"
"io"
"log/slog"
"testing"
"github.com/spf13/cobra"
)
// TestExecute calls Execute with a logger that discards all output and
// expects it to return no error.
func TestExecute(t *testing.T) {
	quiet := slog.New(slog.NewTextHandler(io.Discard, nil))
	if err := Execute(quiet); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
}
// executeCommand runs cmd with args, capturing both its standard output and
// standard error into a single buffer. It returns the captured text together
// with the error from cmd.Execute.
func executeCommand(cmd *cobra.Command, args ...string) (string, error) {
	var captured bytes.Buffer
	cmd.SetOut(&captured)
	cmd.SetErr(&captured)
	cmd.SetArgs(args)

	runErr := cmd.Execute()
	return captured.String(), runErr
}
// Test_NewRootCmd checks that NewRootCmd returns a non-nil command.
func Test_NewRootCmd(t *testing.T) {
	if root := NewRootCmd(); root == nil {
		t.Errorf("NewRootCmd is nil")
	}
}
// Test_executeCommand verifies that executeCommand reports an error exactly
// when one is expected. Only the error result is checked; the captured output
// is intentionally not compared, so the table carries no expected-output field.
func Test_executeCommand(t *testing.T) {
	tests := []struct {
		name    string
		cmd     *cobra.Command
		args    []string
		wantErr bool
	}{
		{
			name:    "Test with no args",
			cmd:     NewRootCmd(),
			args:    []string{},
			wantErr: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			_, err := executeCommand(tt.cmd, tt.args...)
			if (err != nil) != tt.wantErr {
				t.Errorf("executeCommand() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}

12
main.go
View file

@ -7,11 +7,21 @@ import (
"github.com/Snider/Borg/pkg/logger"
)
var osExit = os.Exit
func main() {
verbose, _ := cmd.RootCmd.PersistentFlags().GetBool("verbose")
log := logger.New(verbose)
if err := cmd.Execute(log); err != nil {
log.Error("fatal error", "err", err)
os.Exit(1)
osExit(1)
}
}
// Main builds a logger from the root command's "verbose" flag, runs the root
// command, and on failure logs the error and exits with status 1 through
// osExit.
//
// NOTE(review): the flag value is read before cmd.Execute runs, so it is
// presumably still the flag's default at this point — confirm.
func Main() {
	verbose, _ := cmd.RootCmd.PersistentFlags().GetBool("verbose")
	log := logger.New(verbose)

	err := cmd.Execute(log)
	if err == nil {
		return
	}
	log.Error("fatal error", "err", err)
	osExit(1)
}

View file

@ -0,0 +1,59 @@
package compress
import (
"bytes"
"testing"
)
// TestCompressDecompress round-trips a small payload through Compress and
// Decompress for the "gz", "xz" and "none" formats, and additionally checks
// that "none" leaves the payload unchanged.
func TestCompressDecompress(t *testing.T) {
	payload := []byte("hello, world")
	matchesPayload := func(got []byte) bool { return bytes.Equal(payload, got) }

	// gzip round trip
	gzPacked, err := Compress(payload, "gz")
	if err != nil {
		t.Fatalf("gzip compression failed: %v", err)
	}
	gzPlain, err := Decompress(gzPacked)
	if err != nil {
		t.Fatalf("gzip decompression failed: %v", err)
	}
	if !matchesPayload(gzPlain) {
		t.Errorf("gzip decompressed data does not match original data")
	}

	// xz round trip
	xzPacked, err := Compress(payload, "xz")
	if err != nil {
		t.Fatalf("xz compression failed: %v", err)
	}
	xzPlain, err := Decompress(xzPacked)
	if err != nil {
		t.Fatalf("xz decompression failed: %v", err)
	}
	if !matchesPayload(xzPlain) {
		t.Errorf("xz decompressed data does not match original data")
	}

	// "none" must pass the bytes through untouched in both directions
	rawPacked, err := Compress(payload, "none")
	if err != nil {
		t.Fatalf("no compression failed: %v", err)
	}
	if !matchesPayload(rawPacked) {
		t.Errorf("no compression data does not match original data")
	}
	rawPlain, err := Decompress(rawPacked)
	if err != nil {
		t.Fatalf("no compression decompression failed: %v", err)
	}
	if !matchesPayload(rawPlain) {
		t.Errorf("no compression decompressed data does not match original data")
	}
}

View file

@ -27,11 +27,8 @@ func NewGithubClient() GithubClient {
type githubClient struct{}
func (g *githubClient) GetPublicRepos(ctx context.Context, userOrOrg string) ([]string, error) {
return g.getPublicReposWithAPIURL(ctx, "https://api.github.com", userOrOrg)
}
func (g *githubClient) newAuthenticatedClient(ctx context.Context) *http.Client {
// NewAuthenticatedClient creates a new authenticated http client.
var NewAuthenticatedClient = func(ctx context.Context) *http.Client {
token := os.Getenv("GITHUB_TOKEN")
if token == "" {
return http.DefaultClient
@ -42,8 +39,12 @@ func (g *githubClient) newAuthenticatedClient(ctx context.Context) *http.Client
return oauth2.NewClient(ctx, ts)
}
// GetPublicRepos delegates to getPublicReposWithAPIURL using the
// "https://api.github.com" base URL and returns its result unchanged.
func (g *githubClient) GetPublicRepos(ctx context.Context, userOrOrg string) ([]string, error) {
return g.getPublicReposWithAPIURL(ctx, "https://api.github.com", userOrOrg)
}
func (g *githubClient) getPublicReposWithAPIURL(ctx context.Context, apiURL, userOrOrg string) ([]string, error) {
client := g.newAuthenticatedClient(ctx)
client := NewAuthenticatedClient(ctx)
var allCloneURLs []string
url := fmt.Sprintf("%s/users/%s/repos", apiURL, userOrOrg)

109
pkg/github/github_test.go Normal file
View file

@ -0,0 +1,109 @@
package github
import (
"bytes"
"context"
"io"
"net/http"
"net/url"
"testing"
"github.com/Snider/Borg/pkg/mocks"
)
// TestGetPublicRepos swaps NewAuthenticatedClient for a mock HTTP client and
// checks the clone URLs returned for a user and for an organization whose
// listing spans two pages linked through a Link header.
func TestGetPublicRepos(t *testing.T) {
	// okJSON builds a 200 JSON response with the given body and, when link is
	// non-empty, a Link header.
	okJSON := func(body, link string) *http.Response {
		hdr := http.Header{"Content-Type": []string{"application/json"}}
		if link != "" {
			hdr["Link"] = []string{link}
		}
		return &http.Response{
			StatusCode: http.StatusOK,
			Header:     hdr,
			Body:       io.NopCloser(bytes.NewBufferString(body)),
		}
	}
	// sameRepos reports whether got holds exactly the URLs in want, in order.
	sameRepos := func(got []string, want ...string) bool {
		if len(got) != len(want) {
			return false
		}
		for i := range want {
			if got[i] != want[i] {
				return false
			}
		}
		return true
	}

	mockClient := mocks.NewMockClient(map[string]*http.Response{
		"https://api.github.com/users/testuser/repos": okJSON(
			`[{"clone_url": "https://github.com/testuser/repo1.git"}]`, ""),
		"https://api.github.com/orgs/testorg/repos": okJSON(
			`[{"clone_url": "https://github.com/testorg/repo1.git"}]`,
			`<https://api.github.com/organizations/123/repos?page=2>; rel="next"`),
		"https://api.github.com/organizations/123/repos?page=2": okJSON(
			`[{"clone_url": "https://github.com/testorg/repo2.git"}]`, ""),
	})

	client := &githubClient{}
	// Install the mock for the duration of the test and restore the real
	// factory afterwards.
	original := NewAuthenticatedClient
	t.Cleanup(func() { NewAuthenticatedClient = original })
	NewAuthenticatedClient = func(ctx context.Context) *http.Client {
		return mockClient
	}

	// User repositories: a single page.
	userRepos, err := client.getPublicReposWithAPIURL(context.Background(), "https://api.github.com", "testuser")
	if err != nil {
		t.Fatalf("getPublicReposWithAPIURL for user failed: %v", err)
	}
	if !sameRepos(userRepos, "https://github.com/testuser/repo1.git") {
		t.Errorf("unexpected user repos: %v", userRepos)
	}

	// Organization repositories: two pages joined through the Link header.
	orgRepos, err := client.getPublicReposWithAPIURL(context.Background(), "https://api.github.com", "testorg")
	if err != nil {
		t.Fatalf("getPublicReposWithAPIURL for org failed: %v", err)
	}
	if !sameRepos(orgRepos, "https://github.com/testorg/repo1.git", "https://github.com/testorg/repo2.git") {
		t.Errorf("unexpected org repos: %v", orgRepos)
	}
}
// TestGetPublicRepos_Error makes both the user and the organization endpoints
// answer 404 through a mock HTTP client and expects getPublicReposWithAPIURL
// to return the error "failed to fetch repos: 404 Not Found".
func TestGetPublicRepos_Error(t *testing.T) {
	u, err := url.Parse("https://api.github.com/users/testuser/repos")
	if err != nil {
		t.Fatalf("parsing fixture URL: %v", err)
	}
	mockClient := mocks.NewMockClient(map[string]*http.Response{
		"https://api.github.com/users/testuser/repos": {
			StatusCode: http.StatusNotFound,
			Status:     "404 Not Found",
			Header:     http.Header{"Content-Type": []string{"application/json"}},
			Body:       io.NopCloser(bytes.NewBufferString("")),
			Request:    &http.Request{Method: "GET", URL: u},
		},
		"https://api.github.com/orgs/testuser/repos": {
			StatusCode: http.StatusNotFound,
			Status:     "404 Not Found",
			Header:     http.Header{"Content-Type": []string{"application/json"}},
			Body:       io.NopCloser(bytes.NewBufferString("")),
			Request:    &http.Request{Method: "GET", URL: u},
		},
	})
	expectedErr := "failed to fetch repos: 404 Not Found"

	client := &githubClient{}
	// Install the mock for the duration of the test and restore the real
	// factory afterwards.
	oldClient := NewAuthenticatedClient
	NewAuthenticatedClient = func(ctx context.Context) *http.Client {
		return mockClient
	}
	defer func() {
		NewAuthenticatedClient = oldClient
	}()

	_, err = client.getPublicReposWithAPIURL(context.Background(), "https://api.github.com", "testuser")
	// Guard against a nil error first: calling err.Error() on nil would panic
	// instead of producing a readable test failure.
	if err == nil {
		t.Fatalf("expected error %q, got nil", expectedErr)
	}
	if err.Error() != expectedErr {
		t.Fatalf("unexpected error: got %q, want %q", err.Error(), expectedErr)
	}
}
// TestFindNextURL checks that findNextURL extracts the rel="next" target from
// a Link header and returns an empty string when no such entry exists.
func TestFindNextURL(t *testing.T) {
	client := &githubClient{}

	cases := []struct {
		linkHeader string
		want       string
	}{
		{
			linkHeader: `<https://api.github.com/organizations/123/repos?page=2>; rel="next", <https://api.github.com/organizations/123/repos?page=1>; rel="prev"`,
			want:       "https://api.github.com/organizations/123/repos?page=2",
		},
		{
			linkHeader: `<https://api.github.com/organizations/123/repos?page=1>; rel="prev"`,
			want:       "",
		},
	}

	for _, tc := range cases {
		if got := client.findNextURL(tc.linkHeader); got != tc.want {
			t.Errorf("unexpected next URL: %s", got)
		}
	}
}

View file

@ -1,19 +1,33 @@
package github
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"os"
"strings"
"github.com/google/go-github/v39/github"
)
// Package-level seams that tests can replace to avoid real network access.
var (
	// NewClient builds the GitHub API client; by default it is
	// github.NewClient itself.
	NewClient = github.NewClient

	// NewRequest builds outgoing HTTP requests; by default it is
	// http.NewRequest itself.
	NewRequest = http.NewRequest

	// DefaultClient is the HTTP client used by default. It is a zero-value
	// http.Client, so no Timeout is configured on it.
	DefaultClient = new(http.Client)
)
// GetLatestRelease gets the latest release for a repository.
func GetLatestRelease(owner, repo string) (*github.RepositoryRelease, error) {
client := github.NewClient(nil)
client := NewClient(nil)
release, _, err := client.Repositories.GetLatestRelease(context.Background(), owner, repo)
if err != nil {
return nil, err
@ -22,32 +36,29 @@ func GetLatestRelease(owner, repo string) (*github.RepositoryRelease, error) {
}
// DownloadReleaseAsset downloads a release asset.
//
// NOTE(review): this span carries two signatures and paired statements for the
// same function. One variant takes a destination path, copies the response
// body into that file and returns only an error; the other takes no path,
// copies the body into an in-memory buffer and returns its bytes. Both request
// the asset's browser download URL with an "Accept: application/octet-stream"
// header and report any status other than 200 OK as "bad status: <status>".
func DownloadReleaseAsset(asset *github.ReleaseAsset, path string) error {
client := &http.Client{}
req, err := http.NewRequest("GET", asset.GetBrowserDownloadURL(), nil)
func DownloadReleaseAsset(asset *github.ReleaseAsset) ([]byte, error) {
req, err := NewRequest("GET", asset.GetBrowserDownloadURL(), nil)
if err != nil {
return err
return nil, err
}
req.Header.Set("Accept", "application/octet-stream")
resp, err := client.Do(req)
resp, err := DefaultClient.Do(req)
if err != nil {
return err
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("bad status: %s", resp.Status)
return nil, fmt.Errorf("bad status: %s", resp.Status)
}
f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0644)
buf := new(bytes.Buffer)
_, err = io.Copy(buf, resp.Body)
if err != nil {
return err
return nil, err
}
defer f.Close()
_, err = io.Copy(f, resp.Body)
return err
return buf.Bytes(), nil
}
// ParseRepoFromURL parses the owner and repository from a GitHub URL.

185
pkg/github/release_test.go Normal file
View file

@ -0,0 +1,185 @@
package github
import (
"bytes"
"fmt"
"io"
"net/http"
"net/url"
"testing"
"github.com/Snider/Borg/pkg/mocks"
"github.com/google/go-github/v39/github"
)
// errorRoundTripper is an http.RoundTripper stub whose every round trip fails.
// Tests install it as a client transport to exercise the request-error path.
type errorRoundTripper struct{}

// RoundTrip never produces a response; it always reports "do error".
func (*errorRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	var noResponse *http.Response
	return noResponse, fmt.Errorf("do error")
}
// TestParseRepoFromURL feeds ParseRepoFromURL several GitHub URL spellings and
// checks the owner, repository and error outcome reported for each one.
func TestParseRepoFromURL(t *testing.T) {
	type scenario struct {
		input     string
		wantOwner string
		wantRepo  string
		wantErr   bool
	}
	scenarios := []scenario{
		{input: "https://github.com/owner/repo.git", wantOwner: "owner", wantRepo: "repo"},
		{input: "http://github.com/owner/repo", wantOwner: "owner", wantRepo: "repo"},
		{input: "git://github.com/owner/repo.git", wantOwner: "owner", wantRepo: "repo"},
		{input: "github.com/owner/repo", wantOwner: "owner", wantRepo: "repo"},
		{input: "git@github.com:owner/repo.git", wantOwner: "owner", wantRepo: "repo"},
		{input: "https://github.com/owner/repo/tree/main", wantErr: true},
		{input: "invalid-url", wantErr: true},
	}

	for i := range scenarios {
		sc := scenarios[i]
		gotOwner, gotRepo, err := ParseRepoFromURL(sc.input)
		if failed := err != nil; failed != sc.wantErr {
			t.Errorf("unexpected error for URL %s: %v", sc.input, err)
		}
		if !(gotOwner == sc.wantOwner && gotRepo == sc.wantRepo) {
			t.Errorf("unexpected owner/repo for URL %s: %s/%s", sc.input, gotOwner, gotRepo)
		}
	}
}
// TestGetLatestRelease verifies that GetLatestRelease returns the release
// decoded from a successful "releases/latest" API response.
func TestGetLatestRelease(t *testing.T) {
	mockClient := mocks.NewMockClient(map[string]*http.Response{
		"https://api.github.com/repos/owner/repo/releases/latest": {
			StatusCode: http.StatusOK,
			Body:       io.NopCloser(bytes.NewBufferString(`{"tag_name": "v1.0.0"}`)),
		},
	})
	client := github.NewClient(mockClient)

	// NewClient is package-level state: put the original factory back when the
	// test ends so the mock does not leak into tests that run afterwards.
	oldNewClient := NewClient
	NewClient = func(_ *http.Client) *github.Client {
		return client
	}
	defer func() {
		NewClient = oldNewClient
	}()

	release, err := GetLatestRelease("owner", "repo")
	if err != nil {
		t.Fatalf("GetLatestRelease failed: %v", err)
	}
	if release.GetTagName() != "v1.0.0" {
		t.Errorf("unexpected tag name: %s", release.GetTagName())
	}
}
func TestDownloadReleaseAsset(t *testing.T) {
mockClient := mocks.NewMockClient(map[string]*http.Response{
"https://github.com/owner/repo/releases/download/v1.0.0/asset.zip": {
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBufferString("asset content")),
},
})
asset := &github.ReleaseAsset{
BrowserDownloadURL: github.String("https://github.com/owner/repo/releases/download/v1.0.0/asset.zip"),
}
oldClient := DefaultClient
DefaultClient = mockClient
defer func() {
DefaultClient = oldClient
}()
data, err := DownloadReleaseAsset(asset)
if err != nil {
t.Fatalf("DownloadReleaseAsset failed: %v", err)
}
if string(data) != "asset content" {
t.Errorf("unexpected asset content: %s", string(data))
}
}
// TestDownloadReleaseAsset_BadRequest verifies that a non-200 response is
// reported as a "bad status" error carrying the response status text.
func TestDownloadReleaseAsset_BadRequest(t *testing.T) {
	mockClient := mocks.NewMockClient(map[string]*http.Response{
		"https://github.com/owner/repo/releases/download/v1.0.0/asset.zip": {
			StatusCode: http.StatusBadRequest,
			Status:     "400 Bad Request",
			Body:       io.NopCloser(bytes.NewBufferString("")),
		},
	})
	expectedErr := "bad status: 400 Bad Request"

	asset := &github.ReleaseAsset{
		BrowserDownloadURL: github.String("https://github.com/owner/repo/releases/download/v1.0.0/asset.zip"),
	}

	oldClient := DefaultClient
	DefaultClient = mockClient
	defer func() {
		DefaultClient = oldClient
	}()

	_, err := DownloadReleaseAsset(asset)
	// Check for nil first: calling err.Error() on a nil error would panic and
	// replace a clear test failure with a crash.
	if err == nil {
		t.Fatalf("DownloadReleaseAsset succeeded, expected error %q", expectedErr)
	}
	if err.Error() != expectedErr {
		t.Fatalf("DownloadReleaseAsset failed: %v", err)
	}
}
// TestDownloadReleaseAsset_NewRequestError verifies that DownloadReleaseAsset
// returns an error when the request cannot be constructed.
func TestDownloadReleaseAsset_NewRequestError(t *testing.T) {
	errRequest := fmt.Errorf("bad request")
	asset := &github.ReleaseAsset{
		BrowserDownloadURL: github.String("https://github.com/owner/repo/releases/download/v1.0.0/asset.zip"),
	}

	// Force request construction to fail, restoring the real factory afterwards.
	oldNewRequest := NewRequest
	NewRequest = func(method, url string, body io.Reader) (*http.Request, error) {
		return nil, errRequest
	}
	defer func() {
		NewRequest = oldNewRequest
	}()

	_, err := DownloadReleaseAsset(asset)
	if err == nil {
		// This branch means the call unexpectedly succeeded, so say that
		// instead of printing a nil error as if it were the failure.
		t.Fatalf("DownloadReleaseAsset succeeded, expected error %q", errRequest)
	}
}
// TestGetLatestRelease_Error verifies that a 404 from the "releases/latest"
// endpoint is surfaced by GetLatestRelease as the expected error text.
func TestGetLatestRelease_Error(t *testing.T) {
	const endpoint = "https://api.github.com/repos/owner/repo/releases/latest"

	u, err := url.Parse(endpoint)
	if err != nil {
		t.Fatalf("url.Parse failed: %v", err)
	}
	mockClient := mocks.NewMockClient(map[string]*http.Response{
		endpoint: {
			StatusCode: http.StatusNotFound,
			Status:     "404 Not Found",
			Header:     http.Header{"Content-Type": []string{"application/json"}},
			Body:       io.NopCloser(bytes.NewBufferString("")),
			// The expected error text below includes the method and URL of
			// this request.
			Request: &http.Request{Method: "GET", URL: u},
		},
	})
	expectedErr := "GET https://api.github.com/repos/owner/repo/releases/latest: 404 []"

	client := github.NewClient(mockClient)

	// NewClient is package-level state: put the original factory back when the
	// test ends so the mock does not leak into tests that run afterwards.
	oldNewClient := NewClient
	NewClient = func(_ *http.Client) *github.Client {
		return client
	}
	defer func() {
		NewClient = oldNewClient
	}()

	_, err = GetLatestRelease("owner", "repo")
	// Check for nil first: calling err.Error() on a nil error would panic.
	if err == nil {
		t.Fatalf("GetLatestRelease succeeded, expected error %q", expectedErr)
	}
	if err.Error() != expectedErr {
		t.Fatalf("GetLatestRelease failed: %v", err)
	}
}
func TestDownloadReleaseAsset_DoError(t *testing.T) {
mockClient := &http.Client{
Transport: &errorRoundTripper{},
}
asset := &github.ReleaseAsset{
BrowserDownloadURL: github.String("https://github.com/owner/repo/releases/download/v1.0.0/asset.zip"),
}
oldClient := DefaultClient
DefaultClient = mockClient
defer func() {
DefaultClient = oldClient
}()
_, err := DownloadReleaseAsset(asset)
if err == nil {
t.Fatalf("DownloadReleaseAsset should have failed")
}
}

64
pkg/logger/logger_test.go Normal file
View file

@ -0,0 +1,64 @@
package logger
import (
"bytes"
"context"
"io"
"os"
"log/slog"
"testing"
)
// captureStderr points os.Stderr at a pipe while fn runs and returns everything
// written to it. os.Stderr is restored before returning, even if fn panics.
// Output is read only after fn finishes, so fn should write small amounts.
func captureStderr(t *testing.T, fn func()) []byte {
	t.Helper()

	r, w, err := os.Pipe()
	if err != nil {
		t.Fatalf("os.Pipe failed: %v", err)
	}
	defer r.Close()

	oldStderr := os.Stderr
	os.Stderr = w
	func() {
		// Close the write end so the read below sees EOF, and undo the
		// redirection regardless of how fn exits.
		defer func() {
			w.Close()
			os.Stderr = oldStderr
		}()
		fn()
	}()

	var buf bytes.Buffer
	if _, err := io.Copy(&buf, r); err != nil {
		t.Fatalf("reading captured stderr failed: %v", err)
	}
	return buf.Bytes()
}

// TestNew checks which messages reach stderr for a non-verbose and a verbose
// logger: info messages always, debug messages only when verbose.
func TestNew(t *testing.T) {
	// The logger is created inside the callback, i.e. after stderr has been
	// redirected, matching the order the assertions below depend on.
	logBoth := func(verbose bool) func() {
		return func() {
			log := New(verbose)
			log.Info("info message")
			log.Debug("debug message")
		}
	}

	// Test non-verbose logger
	out := captureStderr(t, logBoth(false))
	if !bytes.Contains(out, []byte("info message")) {
		t.Errorf("expected info message to be logged")
	}
	if bytes.Contains(out, []byte("debug message")) {
		t.Errorf("expected debug message not to be logged")
	}

	// Test verbose logger
	out = captureStderr(t, logBoth(true))
	if !bytes.Contains(out, []byte("info message")) {
		t.Errorf("expected info message to be logged")
	}
	if !bytes.Contains(out, []byte("debug message")) {
		t.Errorf("expected debug message to be logged")
	}
}
// TestNew_Level checks that the debug level is enabled for a verbose logger
// and disabled for a non-verbose one.
func TestNew_Level(t *testing.T) {
	ctx := context.Background()

	if verbose := New(true); !verbose.Enabled(ctx, slog.LevelDebug) {
		t.Errorf("expected debug level to be enabled")
	}
	if quiet := New(false); quiet.Enabled(ctx, slog.LevelDebug) {
		t.Errorf("expected debug level to be disabled")
	}
}

27
pkg/mocks/mock_vcs.go Normal file
View file

@ -0,0 +1,27 @@
package mocks
import (
"io"
"github.com/Snider/Borg/pkg/datanode"
"github.com/Snider/Borg/pkg/vcs"
)
// MockGitCloner is a mock implementation of the GitCloner interface.
// Its CloneGitRepository method returns the DN and Err fields unchanged.
type MockGitCloner struct {
// DN is the DataNode returned by CloneGitRepository.
DN *datanode.DataNode
// Err is the error returned by CloneGitRepository.
Err error
}
// NewMockGitCloner builds a MockGitCloner that stores dn and err in its DN and
// Err fields, and returns it as a vcs.GitCloner.
func NewMockGitCloner(dn *datanode.DataNode, err error) vcs.GitCloner {
	mock := new(MockGitCloner)
	mock.DN, mock.Err = dn, err
	return mock
}
// CloneGitRepository mocks the cloning of a Git repository.
// It ignores repoURL and progress and returns the mock's DN and Err fields.
func (m *MockGitCloner) CloneGitRepository(repoURL string, progress io.Writer) (*datanode.DataNode, error) {
return m.DN, m.Err
}

View file

@ -6,49 +6,31 @@ import (
"io"
"net/http"
"net/url"
"path"
"strings"
"github.com/Snider/Borg/pkg/datanode"
"github.com/schollz/progressbar/v3"
"golang.org/x/net/html"
)
// Manifest represents a simple PWA manifest structure.
// Each field is populated from the JSON key named in its struct tag.
type Manifest struct {
// Name is read from the "name" key.
Name string `json:"name"`
// ShortName is read from the "short_name" key.
ShortName string `json:"short_name"`
// StartURL is read from the "start_url" key.
StartURL string `json:"start_url"`
// Icons is read from the "icons" array.
Icons []Icon `json:"icons"`
}
// Icon represents an icon in the PWA manifest.
// Each field is populated from the JSON key named in its struct tag.
type Icon struct {
// Src is read from the "src" key.
Src string `json:"src"`
// Sizes is read from the "sizes" key.
Sizes string `json:"sizes"`
// Type is read from the "type" key.
Type string `json:"type"`
}
// PWAClient is an interface for finding and downloading PWAs.
// PWAClient is an interface for interacting with PWAs.
//
// NOTE(review): each method is listed twice in this span; the two listings
// differ only in parameter naming (pageURL/baseURL versus pwaURL).
type PWAClient interface {
// FindManifest takes a page URL and returns a string and an error.
FindManifest(pageURL string) (string, error)
// DownloadAndPackagePWA takes a base URL, a manifest URL and a progress bar
// and returns a DataNode and an error.
DownloadAndPackagePWA(baseURL string, manifestURL string, bar *progressbar.ProgressBar) (*datanode.DataNode, error)
FindManifest(pwaURL string) (string, error)
DownloadAndPackagePWA(pwaURL, manifestURL string, bar *progressbar.ProgressBar) (*datanode.DataNode, error)
}
// NewPWAClient creates a new PWAClient.
// The returned value is a *pwaClient whose HTTP client is http.DefaultClient.
//
// NOTE(review): the multi-line and single-line return statements in this span
// construct the same value.
func NewPWAClient() PWAClient {
return &pwaClient{
client: http.DefaultClient,
}
return &pwaClient{client: http.DefaultClient}
}
// pwaClient is the PWAClient implementation returned by NewPWAClient.
type pwaClient struct {
// client is the HTTP client used to issue this type's GET requests.
client *http.Client
}
// FindManifest finds the manifest URL from a given HTML page.
func (p *pwaClient) FindManifest(pageURL string) (string, error) {
resp, err := p.client.Get(pageURL)
// FindManifest finds the manifest for a PWA.
func (p *pwaClient) FindManifest(pwaURL string) (string, error) {
resp, err := p.client.Get(pwaURL)
if err != nil {
return "", err
}
@ -59,100 +41,124 @@ func (p *pwaClient) FindManifest(pageURL string) (string, error) {
return "", err
}
var manifestPath string
var manifestURL string
var f func(*html.Node)
f = func(n *html.Node) {
if n.Type == html.ElementNode && n.Data == "link" {
isManifest := false
var isManifest bool
var href string
for _, a := range n.Attr {
if a.Key == "rel" && a.Val == "manifest" {
isManifest = true
break
}
}
if isManifest {
for _, a := range n.Attr {
if a.Key == "href" {
manifestPath = a.Val
return // exit once found
href = a.Val
}
}
if isManifest && href != "" {
manifestURL = href
return
}
}
for c := n.FirstChild; c != nil && manifestPath == ""; c = c.NextSibling {
for c := n.FirstChild; c != nil && manifestURL == ""; c = c.NextSibling {
f(c)
}
}
f(doc)
if manifestPath == "" {
if manifestURL == "" {
return "", fmt.Errorf("manifest not found")
}
resolvedURL, err := p.resolveURL(pageURL, manifestPath)
resolvedURL, err := p.resolveURL(pwaURL, manifestURL)
if err != nil {
return "", fmt.Errorf("could not resolve manifest URL: %w", err)
return "", err
}
return resolvedURL.String(), nil
}
// DownloadAndPackagePWA downloads all assets of a PWA and packages them into a DataNode.
func (p *pwaClient) DownloadAndPackagePWA(baseURL string, manifestURL string, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
if bar == nil {
return nil, fmt.Errorf("progress bar cannot be nil")
}
manifestAbsURL, err := p.resolveURL(baseURL, manifestURL)
if err != nil {
return nil, fmt.Errorf("could not resolve manifest URL: %w", err)
// DownloadAndPackagePWA downloads and packages a PWA into a DataNode.
func (p *pwaClient) DownloadAndPackagePWA(pwaURL, manifestURL string, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
dn := datanode.New()
type Manifest struct {
StartURL string `json:"start_url"`
Icons []struct {
Src string `json:"src"`
} `json:"icons"`
}
resp, err := p.client.Get(manifestAbsURL.String())
downloadAndAdd := func(assetURL string) error {
if bar != nil {
bar.Add(1)
}
resp, err := p.client.Get(assetURL)
if err != nil {
return nil, fmt.Errorf("could not download manifest: %w", err)
return fmt.Errorf("failed to download %s: %w", assetURL, err)
}
defer resp.Body.Close()
manifestBody, err := io.ReadAll(resp.Body)
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("could not read manifest body: %w", err)
return fmt.Errorf("failed to read body of %s: %w", assetURL, err)
}
u, err := url.Parse(assetURL)
if err != nil {
return fmt.Errorf("failed to parse asset URL %s: %w", assetURL, err)
}
dn.AddData(strings.TrimPrefix(u.Path, "/"), body)
return nil
}
// Download manifest
if err := downloadAndAdd(manifestURL); err != nil {
return nil, err
}
// Parse manifest and download assets
var manifestPath string
u, parseErr := url.Parse(manifestURL)
if parseErr != nil {
manifestPath = "manifest.json"
} else {
manifestPath = strings.TrimPrefix(u.Path, "/")
}
manifestFile, err := dn.Open(manifestPath)
if err != nil {
return nil, fmt.Errorf("failed to open manifest from datanode: %w", err)
}
defer manifestFile.Close()
manifestData, err := io.ReadAll(manifestFile)
if err != nil {
return nil, fmt.Errorf("failed to read manifest from datanode: %w", err)
}
var manifest Manifest
if err := json.Unmarshal(manifestBody, &manifest); err != nil {
return nil, fmt.Errorf("could not parse manifest JSON: %w", err)
if err := json.Unmarshal(manifestData, &manifest); err != nil {
return nil, fmt.Errorf("failed to parse manifest: %w", err)
}
dn := datanode.New()
dn.AddData("manifest.json", manifestBody)
if manifest.StartURL != "" {
startURLAbs, err := p.resolveURL(manifestAbsURL.String(), manifest.StartURL)
// Download start_url
startURL, err := p.resolveURL(manifestURL, manifest.StartURL)
if err != nil {
return nil, fmt.Errorf("could not resolve start_url: %w", err)
}
err = p.downloadAndAddFile(dn, startURLAbs, manifest.StartURL, bar)
if err != nil {
return nil, fmt.Errorf("failed to download start_url asset: %w", err)
return nil, fmt.Errorf("failed to resolve start_url: %w", err)
}
if err := downloadAndAdd(startURL.String()); err != nil {
return nil, err
}
// Download icons
for _, icon := range manifest.Icons {
iconURLAbs, err := p.resolveURL(manifestAbsURL.String(), icon.Src)
iconURL, err := p.resolveURL(manifestURL, icon.Src)
if err != nil {
fmt.Printf("Warning: could not resolve icon URL %s: %v\n", icon.Src, err)
// Skip icons with bad URLs
continue
}
err = p.downloadAndAddFile(dn, iconURLAbs, icon.Src, bar)
if err != nil {
fmt.Printf("Warning: failed to download icon %s: %v\n", icon.Src, err)
}
}
baseURLAbs, _ := url.Parse(baseURL)
err = p.downloadAndAddFile(dn, baseURLAbs, "index.html", bar)
if err != nil {
return nil, fmt.Errorf("failed to download base HTML: %w", err)
downloadAndAdd(iconURL.String())
}
return dn, nil
@ -170,22 +176,28 @@ func (p *pwaClient) resolveURL(base, ref string) (*url.URL, error) {
return baseURL.ResolveReference(refURL), nil
}
func (p *pwaClient) downloadAndAddFile(dn *datanode.DataNode, fileURL *url.URL, internalPath string, bar *progressbar.ProgressBar) error {
resp, err := p.client.Get(fileURL.String())
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("bad status: %s", resp.Status)
// MockPWAClient is a mock implementation of the PWAClient interface.
// Its methods return the values stored in these fields.
type MockPWAClient struct {
// ManifestURL is the URL returned by FindManifest.
ManifestURL string
// DN is the DataNode returned by DownloadAndPackagePWA.
DN *datanode.DataNode
// Err is the error returned by both FindManifest and DownloadAndPackagePWA.
Err error
}
data, err := io.ReadAll(resp.Body)
if err != nil {
return err
// NewMockPWAClient creates a new MockPWAClient.
func NewMockPWAClient(manifestURL string, dn *datanode.DataNode, err error) PWAClient {
return &MockPWAClient{
ManifestURL: manifestURL,
DN: dn,
Err: err,
}
dn.AddData(path.Clean(internalPath), data)
bar.Add(1)
return nil
}
// FindManifest mocks the finding of a PWA manifest.
// It ignores pwaURL and returns the mock's ManifestURL and Err fields.
func (m *MockPWAClient) FindManifest(pwaURL string) (string, error) {
return m.ManifestURL, m.Err
}
// DownloadAndPackagePWA mocks the downloading and packaging of a PWA.
// It ignores all of its arguments and returns the mock's DN and Err fields.
func (m *MockPWAClient) DownloadAndPackagePWA(pwaURL, manifestURL string, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
return m.DN, m.Err
}

View file

@ -1,32 +1,15 @@
package pwa
import (
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/schollz/progressbar/v3"
)
// newTestPWAClient returns the PWAClient used by the tests in this file.
//
// NOTE(review): this span holds two return statements, one building a
// pwaClient around an explicitly configured http.Transport and one delegating
// to NewPWAClient. serverURL is not referenced by either.
func newTestPWAClient(serverURL string) PWAClient {
return &pwaClient{
client: &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
},
},
}
return NewPWAClient()
}
func TestFindManifest(t *testing.T) {
@ -111,6 +94,7 @@ func TestDownloadAndPackagePWA(t *testing.T) {
expectedFiles := []string{"manifest.json", "index.html", "icon.png"}
for _, file := range expectedFiles {
// The path in the datanode is relative to the root of the domain, so we need to remove the leading slash.
exists, err := dn.Exists(file)
if err != nil {
t.Fatalf("Exists failed for %s: %v", file, err)
@ -122,7 +106,7 @@ func TestDownloadAndPackagePWA(t *testing.T) {
}
func TestResolveURL(t *testing.T) {
client := NewPWAClient()
client := NewPWAClient().(*pwaClient)
tests := []struct {
base string
ref string
@ -137,7 +121,7 @@ func TestResolveURL(t *testing.T) {
}
for _, tt := range tests {
got, err := client.(*pwaClient).resolveURL(tt.base, tt.ref)
got, err := client.resolveURL(tt.base, tt.ref)
if err != nil {
t.Errorf("resolveURL(%q, %q) returned error: %v", tt.base, tt.ref, err)
continue