Merge pull request #12 from Snider/feature-improve-test-coverage
feat: Improve test coverage and refactor for testability
Commit 02993fe87b
34 changed files with 1557 additions and 586 deletions
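The recurring pattern in this PR is the dependency seam: collaborators such as the GitHub HTTP client (github.NewAuthenticatedClient), the git cloner (GitCloner), and the process-exit call (osExit) become package-level variables or constructor functions, so tests can swap them for mocks and restore them afterwards. A minimal sketch of the idea, using hypothetical names rather than Borg's actual API:

package example

import "net/http"

// NewHTTPClient is a package-level seam: production code calls it,
// tests reassign it to return a stub client.
var NewHTTPClient = func() *http.Client {
	return http.DefaultClient
}

// fetch depends on the seam instead of constructing its client inline,
// which is what makes it testable without real network access.
func fetch(url string) (*http.Response, error) {
	return NewHTTPClient().Get(url)
}

Tests then follow the same save/replace/defer-restore sequence that appears throughout the new *_test.go files below.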
cmd/all.go (159 changed lines)
@@ -2,66 +2,157 @@ package cmd
import (
	"fmt"
	"log/slog"
	"io"
	"io/fs"
	"net/url"
	"os"
	"strings"

	"github.com/Snider/Borg/pkg/compress"
	"github.com/Snider/Borg/pkg/datanode"
	"github.com/Snider/Borg/pkg/github"
	"github.com/Snider/Borg/pkg/matrix"
	"github.com/Snider/Borg/pkg/ui"

	"github.com/Snider/Borg/pkg/vcs"
	"github.com/spf13/cobra"
)

// allCmd represents the all command
var allCmd = &cobra.Command{
	Use:   "all [user/org]",
	Short: "Collect all public repositories from a user or organization",
	Long:  `Collect all public repositories from a user or organization and store them in a DataNode.`,
	Use:   "all [url]",
	Short: "Collect all resources from a URL",
	Long:  `Collect all resources from a URL, dispatching to the appropriate collector based on the URL type.`,
	Args:  cobra.ExactArgs(1),
	Run: func(cmd *cobra.Command, args []string) {
		logVal := cmd.Context().Value("logger")
		log, ok := logVal.(*slog.Logger)
		if !ok || log == nil {
			fmt.Fprintln(os.Stderr, "Error: logger not properly initialised")
			return
		}
		repos, err := GithubClient.GetPublicRepos(cmd.Context(), args[0])
	RunE: func(cmd *cobra.Command, args []string) error {
		url := args[0]
		outputFile, _ := cmd.Flags().GetString("output")
		format, _ := cmd.Flags().GetString("format")
		compression, _ := cmd.Flags().GetString("compression")

		owner, err := parseGithubOwner(url)
		if err != nil {
			log.Error("failed to get public repos", "err", err)
			return
			return err
		}

		outputDir, _ := cmd.Flags().GetString("output")
		repos, err := GithubClient.GetPublicRepos(cmd.Context(), owner)
		if err != nil {
			return err
		}

		prompter := ui.NewNonInteractivePrompter(ui.GetVCSQuote)
		prompter.Start()
		defer prompter.Stop()

		var progressWriter io.Writer
		if prompter.IsInteractive() {
			bar := ui.NewProgressBar(len(repos), "Cloning repositories")
			progressWriter = ui.NewProgressWriter(bar)
		}

		cloner := vcs.NewGitCloner()
		allDataNodes := datanode.New()

		for _, repoURL := range repos {
			log.Info("cloning repository", "url", repoURL)
			bar := ui.NewProgressBar(-1, "Cloning repository")

			dn, err := GitCloner.CloneGitRepository(repoURL, bar)
			bar.Finish()
			dn, err := cloner.CloneGitRepository(repoURL, progressWriter)
			if err != nil {
				log.Error("failed to clone repository", "url", repoURL, "err", err)
				// Log the error and continue
				fmt.Fprintln(cmd.ErrOrStderr(), "Error cloning repository:", err)
				continue
			}
			// This is not an efficient way to merge datanodes, but it's the only way for now
			// A better approach would be to add a Merge method to the DataNode
			repoName := strings.TrimSuffix(repoURL, ".git")
			parts := strings.Split(repoName, "/")
			repoName = parts[len(parts)-1]

			data, err := dn.ToTar()
			err = dn.Walk(".", func(path string, de fs.DirEntry, err error) error {
				if err != nil {
					return err
				}
				if !de.IsDir() {
					err := func() error {
						file, err := dn.Open(path)
						if err != nil {
							return err
						}
						defer file.Close()
						data, err := io.ReadAll(file)
						if err != nil {
							return err
						}
						allDataNodes.AddData(repoName+"/"+path, data)
						return nil
					}()
					if err != nil {
						return err
					}
				}
				return nil
			})
			if err != nil {
				log.Error("failed to serialize datanode", "url", repoURL, "err", err)
				continue
			}

			repoName := strings.Split(repoURL, "/")[len(strings.Split(repoURL, "/"))-1]
			outputFile := fmt.Sprintf("%s/%s.dat", outputDir, repoName)
			err = os.WriteFile(outputFile, data, 0644)
			if err != nil {
				log.Error("failed to write datanode to file", "url", repoURL, "err", err)
				fmt.Fprintln(cmd.ErrOrStderr(), "Error walking datanode:", err)
				continue
			}
		}

		var data []byte
		if format == "matrix" {
			matrix, err := matrix.FromDataNode(allDataNodes)
			if err != nil {
				return fmt.Errorf("error creating matrix: %w", err)
			}
			data, err = matrix.ToTar()
			if err != nil {
				return fmt.Errorf("error serializing matrix: %w", err)
			}
		} else {
			data, err = allDataNodes.ToTar()
			if err != nil {
				return fmt.Errorf("error serializing DataNode: %w", err)
			}
		}

		compressedData, err := compress.Compress(data, compression)
		if err != nil {
			return fmt.Errorf("error compressing data: %w", err)
		}

		err = os.WriteFile(outputFile, compressedData, 0644)
		if err != nil {
			return fmt.Errorf("error writing DataNode to file: %w", err)
		}

		fmt.Fprintln(cmd.OutOrStdout(), "All repositories saved to", outputFile)

		return nil
	},
}

// init registers the 'all' command and its flags with the root command.
func init() {
	RootCmd.AddCommand(allCmd)
	allCmd.PersistentFlags().String("output", ".", "Output directory for the DataNodes")
	allCmd.PersistentFlags().String("output", "all.dat", "Output file for the DataNode")
	allCmd.PersistentFlags().String("format", "datanode", "Output format (datanode or matrix)")
	allCmd.PersistentFlags().String("compression", "none", "Compression format (none, gz, or xz)")
}

func parseGithubOwner(u string) (string, error) {
	owner, _, err := github.ParseRepoFromURL(u)
	if err == nil {
		return owner, nil
	}

	parsedURL, err := url.Parse(u)
	if err != nil {
		return "", fmt.Errorf("invalid URL: %w", err)
	}

	path := strings.Trim(parsedURL.Path, "/")
	if path == "" {
		return "", fmt.Errorf("invalid owner URL: %s", u)
	}
	parts := strings.Split(path, "/")
	if len(parts) != 1 || parts[0] == "" {
		return "", fmt.Errorf("invalid owner URL: %s", u)
	}
	return parts[0], nil
}
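The comment in the clone loop above notes that walking each cloned DataNode and re-adding every file is an inefficient stand-in for a real Merge method. A possible shape for that helper, sketched only from the DataNode methods visible in this diff (Walk, Open, AddData); the name and signature are assumptions, not part of the PR:

// Merge copies every regular file from src into dst under prefix.
// Hypothetical helper; this PR does not add it to the DataNode type.
func Merge(dst, src *datanode.DataNode, prefix string) error {
	return src.Walk(".", func(path string, de fs.DirEntry, err error) error {
		if err != nil {
			return err
		}
		if de.IsDir() {
			return nil
		}
		f, err := src.Open(path)
		if err != nil {
			return err
		}
		defer f.Close()
		data, err := io.ReadAll(f)
		if err != nil {
			return err
		}
		dst.AddData(prefix+"/"+path, data)
		return nil
	})
}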
cmd/all_test.go (75 changed lines, new file)
@@ -0,0 +1,75 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/Snider/Borg/pkg/datanode"
|
||||
"github.com/Snider/Borg/pkg/github"
|
||||
"github.com/Snider/Borg/pkg/mocks"
|
||||
)
|
||||
|
||||
func TestAllCmd_Good(t *testing.T) {
|
||||
mockGithubClient := mocks.NewMockClient(map[string]*http.Response{
|
||||
"https://api.github.com/users/testuser/repos": {
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(bytes.NewBufferString(`[{"clone_url": "https://github.com/testuser/repo1.git"}]`)),
|
||||
},
|
||||
})
|
||||
oldNewAuthenticatedClient := github.NewAuthenticatedClient
|
||||
github.NewAuthenticatedClient = func(ctx context.Context) *http.Client {
|
||||
return mockGithubClient
|
||||
}
|
||||
defer func() {
|
||||
github.NewAuthenticatedClient = oldNewAuthenticatedClient
|
||||
}()
|
||||
|
||||
mockCloner := &mocks.MockGitCloner{
|
||||
DN: datanode.New(),
|
||||
Err: nil,
|
||||
}
|
||||
oldCloner := GitCloner
|
||||
GitCloner = mockCloner
|
||||
defer func() {
|
||||
GitCloner = oldCloner
|
||||
}()
|
||||
|
||||
rootCmd := NewRootCmd()
|
||||
rootCmd.AddCommand(allCmd)
|
||||
|
||||
out := filepath.Join(t.TempDir(), "out")
|
||||
_, err := executeCommand(rootCmd, "all", "https://github.com/testuser", "--output", out)
|
||||
if err != nil {
|
||||
t.Fatalf("all command failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllCmd_Bad(t *testing.T) {
|
||||
mockGithubClient := mocks.NewMockClient(map[string]*http.Response{
|
||||
"https://api.github.com/users/testuser/repos": {
|
||||
StatusCode: http.StatusNotFound,
|
||||
Body: io.NopCloser(bytes.NewBufferString(`{"message": "Not Found"}`)),
|
||||
},
|
||||
})
|
||||
oldNewAuthenticatedClient := github.NewAuthenticatedClient
|
||||
github.NewAuthenticatedClient = func(ctx context.Context) *http.Client {
|
||||
return mockGithubClient
|
||||
}
|
||||
defer func() {
|
||||
github.NewAuthenticatedClient = oldNewAuthenticatedClient
|
||||
}()
|
||||
|
||||
rootCmd := NewRootCmd()
|
||||
rootCmd.AddCommand(allCmd)
|
||||
|
||||
out := filepath.Join(t.TempDir(), "out")
|
||||
_, err := executeCommand(rootCmd, "all", "https://github.com/testuser", "--output", out)
|
||||
if err == nil {
|
||||
t.Fatalf("expected an error, but got none")
|
||||
}
|
||||
}
|
||||
|
|
@@ -7,11 +7,13 @@ import (
// collectCmd represents the collect command
var collectCmd = &cobra.Command{
	Use:   "collect",
	Short: "Collect a resource and store it in a DataNode.",
	Long:  `Collect a resource from a git repository, a website, or other URI and store it in a DataNode.`,
	Short: "Collect a resource from a URI.",
	Long:  `Collect a resource from a URI and store it in a DataNode.`,
}

// init registers the collect command with the root.
func init() {
	RootCmd.AddCommand(collectCmd)
}
func NewCollectCmd() *cobra.Command {
	return collectCmd
}
@@ -11,7 +11,9 @@ var collectGithubCmd = &cobra.Command{
	Long: `Collect a resource from a GitHub repository, such as a repository or a release.`,
}

// init registers the GitHub collection parent command.
func init() {
	collectCmd.AddCommand(collectGithubCmd)
}
func GetCollectGithubCmd() *cobra.Command {
	return collectGithubCmd
}
@ -1,136 +1,151 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/Snider/Borg/pkg/datanode"
|
||||
borg_github "github.com/Snider/Borg/pkg/github"
|
||||
gh "github.com/google/go-github/v39/github"
|
||||
"github.com/google/go-github/v39/github"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/mod/semver"
|
||||
)
|
||||
|
||||
// collectGithubReleaseCmd represents the collect github-release command
|
||||
var collectGithubReleaseCmd = &cobra.Command{
|
||||
Use: "release [repository-url]",
|
||||
Short: "Download the latest release of a file from GitHub releases",
|
||||
Long: `Download the latest release of a file from GitHub releases. If the file or URL has a version number, it will check for a higher version and download it if found.`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
logVal := cmd.Context().Value("logger")
|
||||
log, ok := logVal.(*slog.Logger)
|
||||
if !ok || log == nil {
|
||||
fmt.Fprintln(os.Stderr, "Error: logger not properly initialised")
|
||||
return
|
||||
}
|
||||
repoURL := args[0]
|
||||
outputDir, _ := cmd.Flags().GetString("output")
|
||||
pack, _ := cmd.Flags().GetBool("pack")
|
||||
file, _ := cmd.Flags().GetString("file")
|
||||
version, _ := cmd.Flags().GetString("version")
|
||||
func NewCollectGithubReleaseCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "release [repository-url]",
|
||||
Short: "Download the latest release of a file from GitHub releases",
|
||||
Long: `Download the latest release of a file from GitHub releases. If the file or URL has a version number, it will check for a higher version and download it if found.`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
logVal := cmd.Context().Value("logger")
|
||||
log, ok := logVal.(*slog.Logger)
|
||||
if !ok || log == nil {
|
||||
return errors.New("logger not properly initialised")
|
||||
}
|
||||
repoURL := args[0]
|
||||
outputDir, _ := cmd.Flags().GetString("output")
|
||||
pack, _ := cmd.Flags().GetBool("pack")
|
||||
file, _ := cmd.Flags().GetString("file")
|
||||
version, _ := cmd.Flags().GetString("version")
|
||||
|
||||
owner, repo, err := borg_github.ParseRepoFromURL(repoURL)
|
||||
if err != nil {
|
||||
log.Error("failed to parse repository url", "err", err)
|
||||
return
|
||||
}
|
||||
|
||||
release, err := borg_github.GetLatestRelease(owner, repo)
|
||||
if err != nil {
|
||||
log.Error("failed to get latest release", "err", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("found latest release", "tag", release.GetTagName())
|
||||
|
||||
if version != "" {
|
||||
if !semver.IsValid(version) {
|
||||
log.Error("invalid version string", "version", version)
|
||||
return
|
||||
}
|
||||
if semver.Compare(release.GetTagName(), version) <= 0 {
|
||||
log.Info("latest release is not newer than the provided version", "latest", release.GetTagName(), "provided", version)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if pack {
|
||||
dn := datanode.New()
|
||||
for _, asset := range release.Assets {
|
||||
log.Info("downloading asset", "name", asset.GetName())
|
||||
resp, err := http.Get(asset.GetBrowserDownloadURL())
|
||||
if err != nil {
|
||||
log.Error("failed to download asset", "name", asset.GetName(), "err", err)
|
||||
continue
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
var buf bytes.Buffer
|
||||
_, err = io.Copy(&buf, resp.Body)
|
||||
if err != nil {
|
||||
log.Error("failed to read asset", "name", asset.GetName(), "err", err)
|
||||
continue
|
||||
}
|
||||
dn.AddData(asset.GetName(), buf.Bytes())
|
||||
}
|
||||
tar, err := dn.ToTar()
|
||||
if err != nil {
|
||||
log.Error("failed to create datanode", "err", err)
|
||||
return
|
||||
}
|
||||
outputFile := outputDir
|
||||
if !strings.HasSuffix(outputFile, ".dat") {
|
||||
outputFile = outputFile + ".dat"
|
||||
}
|
||||
err = os.WriteFile(outputFile, tar, 0644)
|
||||
if err != nil {
|
||||
log.Error("failed to write datanode", "err", err)
|
||||
return
|
||||
}
|
||||
log.Info("datanode saved", "path", outputFile)
|
||||
} else {
|
||||
if len(release.Assets) == 0 {
|
||||
log.Info("no assets found in the latest release")
|
||||
return
|
||||
}
|
||||
var assetToDownload *gh.ReleaseAsset
|
||||
if file != "" {
|
||||
for _, asset := range release.Assets {
|
||||
if asset.GetName() == file {
|
||||
assetToDownload = asset
|
||||
break
|
||||
}
|
||||
}
|
||||
if assetToDownload == nil {
|
||||
log.Error("asset not found in the latest release", "asset", file)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
assetToDownload = release.Assets[0]
|
||||
}
|
||||
outputPath := filepath.Join(outputDir, assetToDownload.GetName())
|
||||
log.Info("downloading asset", "name", assetToDownload.GetName())
|
||||
err = borg_github.DownloadReleaseAsset(assetToDownload, outputPath)
|
||||
if err != nil {
|
||||
log.Error("failed to download asset", "name", assetToDownload.GetName(), "err", err)
|
||||
return
|
||||
}
|
||||
log.Info("asset downloaded", "path", outputPath)
|
||||
}
|
||||
},
|
||||
_, err := GetRelease(log, repoURL, outputDir, pack, file, version)
|
||||
return err
|
||||
},
|
||||
}
|
||||
cmd.PersistentFlags().String("output", ".", "Output directory for the downloaded file")
|
||||
cmd.PersistentFlags().Bool("pack", false, "Pack all assets into a DataNode")
|
||||
cmd.PersistentFlags().String("file", "", "The file to download from the release")
|
||||
cmd.PersistentFlags().String("version", "", "The version to check against")
|
||||
return cmd
|
||||
}
|
||||
|
||||
// init registers the 'collect github release' subcommand and its flags.
|
||||
func init() {
|
||||
collectGithubCmd.AddCommand(collectGithubReleaseCmd)
|
||||
collectGithubReleaseCmd.PersistentFlags().String("output", ".", "Output directory for the downloaded file")
|
||||
collectGithubReleaseCmd.PersistentFlags().Bool("pack", false, "Pack all assets into a DataNode")
|
||||
collectGithubReleaseCmd.PersistentFlags().String("file", "", "The file to download from the release")
|
||||
collectGithubReleaseCmd.PersistentFlags().String("version", "", "The version to check against")
|
||||
collectGithubCmd.AddCommand(NewCollectGithubReleaseCmd())
|
||||
}
|
||||
|
||||
func GetRelease(log *slog.Logger, repoURL string, outputDir string, pack bool, file string, version string) (*github.RepositoryRelease, error) {
|
||||
owner, repo, err := borg_github.ParseRepoFromURL(repoURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse repository url: %w", err)
|
||||
}
|
||||
|
||||
release, err := borg_github.GetLatestRelease(owner, repo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get latest release: %w", err)
|
||||
}
|
||||
|
||||
log.Info("found latest release", "tag", release.GetTagName())
|
||||
|
||||
if version != "" {
|
||||
tag := release.GetTagName()
|
||||
if !semver.IsValid(tag) {
|
||||
log.Info("latest release tag is not a valid semantic version, skipping comparison", "tag", tag)
|
||||
} else {
|
||||
if !semver.IsValid(version) {
|
||||
return nil, fmt.Errorf("invalid version string: %s", version)
|
||||
}
|
||||
if semver.Compare(tag, version) <= 0 {
|
||||
log.Info("latest release is not newer than the provided version", "latest", tag, "provided", version)
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if pack {
|
||||
dn := datanode.New()
|
||||
var failedAssets []string
|
||||
for _, asset := range release.Assets {
|
||||
log.Info("downloading asset", "name", asset.GetName())
|
||||
data, err := borg_github.DownloadReleaseAsset(asset)
|
||||
if err != nil {
|
||||
log.Error("failed to download asset", "name", asset.GetName(), "err", err)
|
||||
failedAssets = append(failedAssets, asset.GetName())
|
||||
continue
|
||||
}
|
||||
dn.AddData(asset.GetName(), data)
|
||||
}
|
||||
if len(failedAssets) > 0 {
|
||||
return nil, fmt.Errorf("failed to download assets: %v", failedAssets)
|
||||
}
|
||||
|
||||
tar, err := dn.ToTar()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create datanode: %w", err)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(outputDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create output directory: %w", err)
|
||||
}
|
||||
basename := release.GetTagName()
|
||||
if basename == "" {
|
||||
basename = "release"
|
||||
}
|
||||
outputFile := filepath.Join(outputDir, basename+".dat")
|
||||
|
||||
err = os.WriteFile(outputFile, tar, 0644)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to write datanode: %w", err)
|
||||
}
|
||||
log.Info("datanode saved", "path", outputFile)
|
||||
} else {
|
||||
if len(release.Assets) == 0 {
|
||||
log.Info("no assets found in the latest release")
|
||||
return nil, nil
|
||||
}
|
||||
var assetToDownload *github.ReleaseAsset
|
||||
if file != "" {
|
||||
for _, asset := range release.Assets {
|
||||
if asset.GetName() == file {
|
||||
assetToDownload = asset
|
||||
break
|
||||
}
|
||||
}
|
||||
if assetToDownload == nil {
|
||||
return nil, fmt.Errorf("asset not found in the latest release: %s", file)
|
||||
}
|
||||
} else {
|
||||
assetToDownload = release.Assets[0]
|
||||
}
|
||||
if outputDir != "" {
|
||||
if err := os.MkdirAll(outputDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create output directory: %w", err)
|
||||
}
|
||||
}
|
||||
outputPath := filepath.Join(outputDir, assetToDownload.GetName())
|
||||
log.Info("downloading asset", "name", assetToDownload.GetName())
|
||||
data, err := borg_github.DownloadReleaseAsset(assetToDownload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to download asset: %w", err)
|
||||
}
|
||||
err = os.WriteFile(outputPath, data, 0644)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to write asset to file: %w", err)
|
||||
}
|
||||
log.Info("asset downloaded", "path", outputPath)
|
||||
}
|
||||
return release, nil
|
||||
}
|
||||
cmd/collect_github_release_subcommand_test.go (110 changed lines, new file)
@@ -0,0 +1,110 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/Snider/Borg/pkg/mocks"
|
||||
borg_github "github.com/Snider/Borg/pkg/github"
|
||||
"github.com/google/go-github/v39/github"
|
||||
)
|
||||
|
||||
func TestGetRelease_Good(t *testing.T) {
|
||||
// Create a temporary directory for the output
|
||||
dir, err := os.MkdirTemp("", "test-get-release")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
mockClient := mocks.NewMockClient(map[string]*http.Response{
|
||||
"https://api.github.com/repos/owner/repo/releases/latest": {
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(bytes.NewBufferString(`{"tag_name": "v1.0.0", "assets": [{"name": "asset1.zip", "browser_download_url": "https://github.com/owner/repo/releases/download/v1.0.0/asset1.zip"}]}`)),
|
||||
},
|
||||
"https://github.com/owner/repo/releases/download/v1.0.0/asset1.zip": {
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(bytes.NewBufferString("asset content")),
|
||||
},
|
||||
})
|
||||
|
||||
oldNewClient := borg_github.NewClient
|
||||
borg_github.NewClient = func(httpClient *http.Client) *github.Client {
|
||||
return github.NewClient(mockClient)
|
||||
}
|
||||
defer func() {
|
||||
borg_github.NewClient = oldNewClient
|
||||
}()
|
||||
|
||||
oldDefaultClient := borg_github.DefaultClient
|
||||
borg_github.DefaultClient = mockClient
|
||||
defer func() {
|
||||
borg_github.DefaultClient = oldDefaultClient
|
||||
}()
|
||||
|
||||
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
|
||||
|
||||
// Test downloading a single asset
|
||||
_, err = GetRelease(log, "https://github.com/owner/repo", dir, false, "asset1.zip", "")
|
||||
if err != nil {
|
||||
t.Fatalf("GetRelease failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify the asset was downloaded
|
||||
content, err := os.ReadFile(filepath.Join(dir, "asset1.zip"))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read downloaded asset: %v", err)
|
||||
}
|
||||
if string(content) != "asset content" {
|
||||
t.Errorf("unexpected asset content: %s", string(content))
|
||||
}
|
||||
|
||||
// Test packing all assets
|
||||
packedDir := filepath.Join(dir, "packed")
|
||||
_, err = GetRelease(log, "https://github.com/owner/repo", packedDir, true, "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("GetRelease with --pack failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify the datanode was created
|
||||
if _, err := os.Stat(filepath.Join(packedDir, "v1.0.0.dat")); os.IsNotExist(err) {
|
||||
t.Fatalf("datanode not created")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRelease_Bad(t *testing.T) {
|
||||
// Create a temporary directory for the output
|
||||
dir, err := os.MkdirTemp("", "test-get-release")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
mockClient := mocks.NewMockClient(map[string]*http.Response{
|
||||
"https://api.github.com/repos/owner/repo/releases/latest": {
|
||||
StatusCode: http.StatusNotFound,
|
||||
Body: io.NopCloser(bytes.NewBufferString(`{"message": "Not Found"}`)),
|
||||
},
|
||||
})
|
||||
|
||||
oldNewClient := borg_github.NewClient
|
||||
borg_github.NewClient = func(httpClient *http.Client) *github.Client {
|
||||
return github.NewClient(mockClient)
|
||||
}
|
||||
defer func() {
|
||||
borg_github.NewClient = oldNewClient
|
||||
}()
|
||||
|
||||
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
|
||||
|
||||
// Test failed release lookup
|
||||
_, err = GetRelease(log, "https://github.com/owner/repo", dir, false, "", "")
|
||||
if err == nil {
|
||||
t.Fatalf("expected an error, but got none")
|
||||
}
|
||||
}
|
||||
|
|
@ -13,86 +13,94 @@ import (
|
|||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultFilePermission = 0644
|
||||
)
|
||||
|
||||
var (
|
||||
// GitCloner is the git cloner used by the command. It can be replaced for testing.
|
||||
GitCloner = vcs.NewGitCloner()
|
||||
)
|
||||
|
||||
// collectGithubRepoCmd represents the collect github repo command
|
||||
var collectGithubRepoCmd = &cobra.Command{
|
||||
Use: "repo [repository-url]",
|
||||
Short: "Collect a single Git repository",
|
||||
Long: `Collect a single Git repository and store it in a DataNode.`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
repoURL := args[0]
|
||||
outputFile, _ := cmd.Flags().GetString("output")
|
||||
format, _ := cmd.Flags().GetString("format")
|
||||
compression, _ := cmd.Flags().GetString("compression")
|
||||
// NewCollectGithubRepoCmd creates a new cobra command for collecting a single git repository.
|
||||
func NewCollectGithubRepoCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "repo [repository-url]",
|
||||
Short: "Collect a single Git repository",
|
||||
Long: `Collect a single Git repository and store it in a DataNode.`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
repoURL := args[0]
|
||||
outputFile, _ := cmd.Flags().GetString("output")
|
||||
format, _ := cmd.Flags().GetString("format")
|
||||
compression, _ := cmd.Flags().GetString("compression")
|
||||
|
||||
prompter := ui.NewNonInteractivePrompter(ui.GetVCSQuote)
|
||||
prompter.Start()
|
||||
defer prompter.Stop()
|
||||
if format != "datanode" && format != "matrix" {
|
||||
return fmt.Errorf("invalid format: %s (must be 'datanode' or 'matrix')", format)
|
||||
}
|
||||
if compression != "none" && compression != "gz" && compression != "xz" {
|
||||
return fmt.Errorf("invalid compression: %s (must be 'none', 'gz', or 'xz')", compression)
|
||||
}
|
||||
|
||||
var progressWriter io.Writer
|
||||
if prompter.IsInteractive() {
|
||||
bar := ui.NewProgressBar(-1, "Cloning repository")
|
||||
progressWriter = ui.NewProgressWriter(bar)
|
||||
}
|
||||
prompter := ui.NewNonInteractivePrompter(ui.GetVCSQuote)
|
||||
prompter.Start()
|
||||
defer prompter.Stop()
|
||||
|
||||
dn, err := GitCloner.CloneGitRepository(repoURL, progressWriter)
|
||||
if err != nil {
|
||||
fmt.Fprintln(cmd.ErrOrStderr(), "Error cloning repository:", err)
|
||||
return
|
||||
}
|
||||
var progressWriter io.Writer
|
||||
if prompter.IsInteractive() {
|
||||
bar := ui.NewProgressBar(-1, "Cloning repository")
|
||||
progressWriter = ui.NewProgressWriter(bar)
|
||||
}
|
||||
|
||||
var data []byte
|
||||
if format == "matrix" {
|
||||
matrix, err := matrix.FromDataNode(dn)
|
||||
dn, err := GitCloner.CloneGitRepository(repoURL, progressWriter)
|
||||
if err != nil {
|
||||
fmt.Fprintln(cmd.ErrOrStderr(), "Error creating matrix:", err)
|
||||
return
|
||||
return fmt.Errorf("error cloning repository: %w", err)
|
||||
}
|
||||
data, err = matrix.ToTar()
|
||||
|
||||
var data []byte
|
||||
if format == "matrix" {
|
||||
matrix, err := matrix.FromDataNode(dn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating matrix: %w", err)
|
||||
}
|
||||
data, err = matrix.ToTar()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error serializing matrix: %w", err)
|
||||
}
|
||||
} else {
|
||||
data, err = dn.ToTar()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error serializing DataNode: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
compressedData, err := compress.Compress(data, compression)
|
||||
if err != nil {
|
||||
fmt.Fprintln(cmd.ErrOrStderr(), "Error serializing matrix:", err)
|
||||
return
|
||||
return fmt.Errorf("error compressing data: %w", err)
|
||||
}
|
||||
} else {
|
||||
data, err = dn.ToTar()
|
||||
|
||||
if outputFile == "" {
|
||||
outputFile = "repo." + format
|
||||
if compression != "none" {
|
||||
outputFile += "." + compression
|
||||
}
|
||||
}
|
||||
|
||||
err = os.WriteFile(outputFile, compressedData, defaultFilePermission)
|
||||
if err != nil {
|
||||
fmt.Fprintln(cmd.ErrOrStderr(), "Error serializing DataNode:", err)
|
||||
return
|
||||
return fmt.Errorf("error writing DataNode to file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
compressedData, err := compress.Compress(data, compression)
|
||||
if err != nil {
|
||||
fmt.Fprintln(cmd.ErrOrStderr(), "Error compressing data:", err)
|
||||
return
|
||||
}
|
||||
|
||||
if outputFile == "" {
|
||||
outputFile = "repo." + format
|
||||
if compression != "none" {
|
||||
outputFile += "." + compression
|
||||
}
|
||||
}
|
||||
|
||||
err = os.WriteFile(outputFile, compressedData, 0644)
|
||||
if err != nil {
|
||||
fmt.Fprintln(cmd.ErrOrStderr(), "Error writing DataNode to file:", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Fprintln(cmd.OutOrStdout(), "Repository saved to", outputFile)
|
||||
},
|
||||
fmt.Fprintln(cmd.OutOrStdout(), "Repository saved to", outputFile)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
cmd.Flags().String("output", "", "Output file for the DataNode")
|
||||
cmd.Flags().String("format", "datanode", "Output format (datanode or matrix)")
|
||||
cmd.Flags().String("compression", "none", "Compression format (none, gz, or xz)")
|
||||
return cmd
|
||||
}
|
||||
|
||||
// init registers the 'collect github repo' subcommand and its flags.
|
||||
func init() {
|
||||
collectGithubCmd.AddCommand(collectGithubRepoCmd)
|
||||
collectGithubRepoCmd.PersistentFlags().String("output", "", "Output file for the DataNode")
|
||||
collectGithubRepoCmd.PersistentFlags().String("format", "datanode", "Output format (datanode or matrix)")
|
||||
collectGithubRepoCmd.PersistentFlags().String("compression", "none", "Compression format (none, gz, or xz)")
|
||||
collectGithubCmd.AddCommand(NewCollectGithubRepoCmd())
|
||||
}
|
||||
|
cmd/collect_github_repo_test.go (52 changed lines, new file)
@@ -0,0 +1,52 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/Snider/Borg/pkg/datanode"
|
||||
"github.com/Snider/Borg/pkg/mocks"
|
||||
)
|
||||
|
||||
func TestCollectGithubRepoCmd_Good(t *testing.T) {
|
||||
mockCloner := &mocks.MockGitCloner{
|
||||
DN: datanode.New(),
|
||||
Err: nil,
|
||||
}
|
||||
oldCloner := GitCloner
|
||||
GitCloner = mockCloner
|
||||
defer func() {
|
||||
GitCloner = oldCloner
|
||||
}()
|
||||
|
||||
rootCmd := NewRootCmd()
|
||||
rootCmd.AddCommand(collectCmd)
|
||||
|
||||
out := filepath.Join(t.TempDir(), "out")
|
||||
_, err := executeCommand(rootCmd, "collect", "github", "repo", "https://github.com/testuser/repo1.git", "--output", out)
|
||||
if err != nil {
|
||||
t.Fatalf("collect github repo command failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectGithubRepoCmd_Bad(t *testing.T) {
|
||||
mockCloner := &mocks.MockGitCloner{
|
||||
DN: nil,
|
||||
Err: fmt.Errorf("git clone error"),
|
||||
}
|
||||
oldCloner := GitCloner
|
||||
GitCloner = mockCloner
|
||||
defer func() {
|
||||
GitCloner = oldCloner
|
||||
}()
|
||||
|
||||
rootCmd := NewRootCmd()
|
||||
rootCmd.AddCommand(collectCmd)
|
||||
|
||||
out := filepath.Join(t.TempDir(), "out")
|
||||
_, err := executeCommand(rootCmd, "collect", "github", "repo", "https://github.com/testuser/repo1.git", "--output", out)
|
||||
if err == nil {
|
||||
t.Fatalf("expected an error, but got none")
|
||||
}
|
||||
}
|
||||
|
|
@@ -28,7 +28,6 @@ var collectGithubReposCmd = &cobra.Command{
	},
}

// init registers the 'collect github repos' subcommand.
func init() {
	collectGithubCmd.AddCommand(collectGithubReposCmd)
}
@ -12,93 +12,103 @@ import (
|
|||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var (
|
||||
// PWAClient is the pwa client used by the command. It can be replaced for testing.
|
||||
PWAClient = pwa.NewPWAClient()
|
||||
)
|
||||
type CollectPWACmd struct {
|
||||
cobra.Command
|
||||
PWAClient pwa.PWAClient
|
||||
}
|
||||
|
||||
// collectPWACmd represents the collect pwa command
|
||||
var collectPWACmd = &cobra.Command{
|
||||
Use: "pwa",
|
||||
Short: "Collect a single PWA using a URI",
|
||||
Long: `Collect a single PWA and store it in a DataNode.
|
||||
// NewCollectPWACmd creates a new collect pwa command
|
||||
func NewCollectPWACmd() *CollectPWACmd {
|
||||
c := &CollectPWACmd{
|
||||
PWAClient: pwa.NewPWAClient(),
|
||||
}
|
||||
c.Command = cobra.Command{
|
||||
Use: "pwa",
|
||||
Short: "Collect a single PWA using a URI",
|
||||
Long: `Collect a single PWA and store it in a DataNode.
|
||||
|
||||
Example:
|
||||
borg collect pwa --uri https://example.com --output mypwa.dat`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
pwaURL, _ := cmd.Flags().GetString("uri")
|
||||
outputFile, _ := cmd.Flags().GetString("output")
|
||||
format, _ := cmd.Flags().GetString("format")
|
||||
compression, _ := cmd.Flags().GetString("compression")
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
pwaURL, _ := cmd.Flags().GetString("uri")
|
||||
outputFile, _ := cmd.Flags().GetString("output")
|
||||
format, _ := cmd.Flags().GetString("format")
|
||||
compression, _ := cmd.Flags().GetString("compression")
|
||||
|
||||
if pwaURL == "" {
|
||||
fmt.Fprintln(cmd.ErrOrStderr(), "Error: uri is required")
|
||||
return
|
||||
}
|
||||
|
||||
bar := ui.NewProgressBar(-1, "Finding PWA manifest")
|
||||
defer bar.Finish()
|
||||
|
||||
manifestURL, err := PWAClient.FindManifest(pwaURL)
|
||||
if err != nil {
|
||||
fmt.Fprintln(cmd.ErrOrStderr(), "Error finding manifest:", err)
|
||||
return
|
||||
}
|
||||
bar.Describe("Downloading and packaging PWA")
|
||||
dn, err := PWAClient.DownloadAndPackagePWA(pwaURL, manifestURL, bar)
|
||||
if err != nil {
|
||||
fmt.Fprintln(cmd.ErrOrStderr(), "Error downloading and packaging PWA:", err)
|
||||
return
|
||||
}
|
||||
|
||||
var data []byte
|
||||
if format == "matrix" {
|
||||
matrix, err := matrix.FromDataNode(dn)
|
||||
finalPath, err := CollectPWA(c.PWAClient, pwaURL, outputFile, format, compression)
|
||||
if err != nil {
|
||||
fmt.Fprintln(cmd.ErrOrStderr(), "Error creating matrix:", err)
|
||||
return
|
||||
return err
|
||||
}
|
||||
data, err = matrix.ToTar()
|
||||
if err != nil {
|
||||
fmt.Fprintln(cmd.ErrOrStderr(), "Error serializing matrix:", err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
data, err = dn.ToTar()
|
||||
if err != nil {
|
||||
fmt.Fprintln(cmd.ErrOrStderr(), "Error serializing DataNode:", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
compressedData, err := compress.Compress(data, compression)
|
||||
if err != nil {
|
||||
fmt.Fprintln(cmd.ErrOrStderr(), "Error compressing data:", err)
|
||||
return
|
||||
}
|
||||
|
||||
if outputFile == "" {
|
||||
outputFile = "pwa." + format
|
||||
if compression != "none" {
|
||||
outputFile += "." + compression
|
||||
}
|
||||
}
|
||||
|
||||
err = os.WriteFile(outputFile, compressedData, 0644)
|
||||
if err != nil {
|
||||
fmt.Fprintln(cmd.ErrOrStderr(), "Error writing PWA to file:", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Fprintln(cmd.OutOrStdout(), "PWA saved to", outputFile)
|
||||
},
|
||||
fmt.Fprintln(cmd.OutOrStdout(), "PWA saved to", finalPath)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
c.Flags().String("uri", "", "The URI of the PWA to collect")
|
||||
c.Flags().String("output", "", "Output file for the DataNode")
|
||||
c.Flags().String("format", "datanode", "Output format (datanode or matrix)")
|
||||
c.Flags().String("compression", "none", "Compression format (none, gz, or xz)")
|
||||
return c
|
||||
}
|
||||
|
||||
// init registers the 'collect pwa' subcommand and its flags.
|
||||
func init() {
|
||||
collectCmd.AddCommand(collectPWACmd)
|
||||
collectPWACmd.Flags().String("uri", "", "The URI of the PWA to collect")
|
||||
collectPWACmd.Flags().String("output", "", "Output file for the DataNode")
|
||||
collectPWACmd.Flags().String("format", "datanode", "Output format (datanode or matrix)")
|
||||
collectPWACmd.Flags().String("compression", "none", "Compression format (none, gz, or xz)")
|
||||
collectCmd.AddCommand(&NewCollectPWACmd().Command)
|
||||
}
|
||||
func CollectPWA(client pwa.PWAClient, pwaURL string, outputFile string, format string, compression string) (string, error) {
|
||||
if pwaURL == "" {
|
||||
return "", fmt.Errorf("uri is required")
|
||||
}
|
||||
if format != "datanode" && format != "matrix" {
|
||||
return "", fmt.Errorf("invalid format: %s (must be 'datanode' or 'matrix')", format)
|
||||
}
|
||||
if compression != "none" && compression != "gz" && compression != "xz" {
|
||||
return "", fmt.Errorf("invalid compression: %s (must be 'none', 'gz', or 'xz')", compression)
|
||||
}
|
||||
|
||||
bar := ui.NewProgressBar(-1, "Finding PWA manifest")
|
||||
defer bar.Finish()
|
||||
|
||||
manifestURL, err := client.FindManifest(pwaURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error finding manifest: %w", err)
|
||||
}
|
||||
bar.Describe("Downloading and packaging PWA")
|
||||
dn, err := client.DownloadAndPackagePWA(pwaURL, manifestURL, bar)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error downloading and packaging PWA: %w", err)
|
||||
}
|
||||
|
||||
var data []byte
|
||||
if format == "matrix" {
|
||||
matrix, err := matrix.FromDataNode(dn)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error creating matrix: %w", err)
|
||||
}
|
||||
data, err = matrix.ToTar()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error serializing matrix: %w", err)
|
||||
}
|
||||
} else {
|
||||
data, err = dn.ToTar()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error serializing DataNode: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
compressedData, err := compress.Compress(data, compression)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error compressing data: %w", err)
|
||||
}
|
||||
|
||||
if outputFile == "" {
|
||||
outputFile = "pwa." + format
|
||||
if compression != "none" {
|
||||
outputFile += "." + compression
|
||||
}
|
||||
}
|
||||
|
||||
err = os.WriteFile(outputFile, compressedData, 0644)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error writing PWA to file: %w", err)
|
||||
}
|
||||
return outputFile, nil
|
||||
}
|
||||
|
cmd/collect_pwa_test.go (61 changed lines, new file)
@@ -0,0 +1,61 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Snider/Borg/pkg/datanode"
|
||||
"github.com/Snider/Borg/pkg/pwa"
|
||||
)
|
||||
|
||||
func TestCollectPWACmd_NoURI(t *testing.T) {
|
||||
rootCmd := NewRootCmd()
|
||||
collectCmd := NewCollectCmd()
|
||||
collectPWACmd := NewCollectPWACmd()
|
||||
collectCmd.AddCommand(&collectPWACmd.Command)
|
||||
rootCmd.AddCommand(collectCmd)
|
||||
_, err := executeCommand(rootCmd, "collect", "pwa")
|
||||
if err == nil {
|
||||
t.Fatalf("expected an error, but got none")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "uri is required") {
|
||||
t.Fatalf("unexpected error message: %v", err)
|
||||
}
|
||||
}
|
||||
func Test_NewCollectPWACmd(t *testing.T) {
|
||||
if NewCollectPWACmd() == nil {
|
||||
t.Errorf("NewCollectPWACmd is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectPWA_Good(t *testing.T) {
|
||||
mockClient := &pwa.MockPWAClient{
|
||||
ManifestURL: "https://example.com/manifest.json",
|
||||
DN: datanode.New(),
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "pwa.dat")
|
||||
_, err := CollectPWA(mockClient, "https://example.com", path, "datanode", "none")
|
||||
if err != nil {
|
||||
t.Fatalf("CollectPWA failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectPWA_Bad(t *testing.T) {
|
||||
mockClient := &pwa.MockPWAClient{
|
||||
ManifestURL: "",
|
||||
DN: nil,
|
||||
Err: fmt.Errorf("pwa error"),
|
||||
}
|
||||
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "pwa.dat")
|
||||
_, err := CollectPWA(mockClient, "https://example.com", path, "datanode", "none")
|
||||
if err == nil {
|
||||
t.Fatalf("expected an error, but got none")
|
||||
}
|
||||
}
|
||||
|
|
@ -4,11 +4,11 @@ import (
|
|||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/schollz/progressbar/v3"
|
||||
"github.com/Snider/Borg/pkg/compress"
|
||||
"github.com/Snider/Borg/pkg/matrix"
|
||||
"github.com/Snider/Borg/pkg/ui"
|
||||
"github.com/Snider/Borg/pkg/website"
|
||||
"github.com/schollz/progressbar/v3"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
|
@ -19,7 +19,7 @@ var collectWebsiteCmd = &cobra.Command{
|
|||
Short: "Collect a single website",
|
||||
Long: `Collect a single website and store it in a DataNode.`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
websiteURL := args[0]
|
||||
outputFile, _ := cmd.Flags().GetString("output")
|
||||
depth, _ := cmd.Flags().GetInt("depth")
|
||||
|
|
@ -36,34 +36,29 @@ var collectWebsiteCmd = &cobra.Command{
|
|||
|
||||
dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar)
|
||||
if err != nil {
|
||||
fmt.Fprintln(cmd.ErrOrStderr(), "Error downloading and packaging website:", err)
|
||||
return
|
||||
return fmt.Errorf("error downloading and packaging website: %w", err)
|
||||
}
|
||||
|
||||
var data []byte
|
||||
if format == "matrix" {
|
||||
matrix, err := matrix.FromDataNode(dn)
|
||||
if err != nil {
|
||||
fmt.Fprintln(cmd.ErrOrStderr(), "Error creating matrix:", err)
|
||||
return
|
||||
return fmt.Errorf("error creating matrix: %w", err)
|
||||
}
|
||||
data, err = matrix.ToTar()
|
||||
if err != nil {
|
||||
fmt.Fprintln(cmd.ErrOrStderr(), "Error serializing matrix:", err)
|
||||
return
|
||||
return fmt.Errorf("error serializing matrix: %w", err)
|
||||
}
|
||||
} else {
|
||||
data, err = dn.ToTar()
|
||||
if err != nil {
|
||||
fmt.Fprintln(cmd.ErrOrStderr(), "Error serializing DataNode:", err)
|
||||
return
|
||||
return fmt.Errorf("error serializing DataNode: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
compressedData, err := compress.Compress(data, compression)
|
||||
if err != nil {
|
||||
fmt.Fprintln(cmd.ErrOrStderr(), "Error compressing data:", err)
|
||||
return
|
||||
return fmt.Errorf("error compressing data: %w", err)
|
||||
}
|
||||
|
||||
if outputFile == "" {
|
||||
|
|
@ -75,15 +70,14 @@ var collectWebsiteCmd = &cobra.Command{
|
|||
|
||||
err = os.WriteFile(outputFile, compressedData, 0644)
|
||||
if err != nil {
|
||||
fmt.Fprintln(cmd.ErrOrStderr(), "Error writing website to file:", err)
|
||||
return
|
||||
return fmt.Errorf("error writing website to file: %w", err)
|
||||
}
|
||||
|
||||
fmt.Fprintln(cmd.OutOrStdout(), "Website saved to", outputFile)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// init registers the 'collect website' subcommand and its flags.
|
||||
func init() {
|
||||
collectCmd.AddCommand(collectWebsiteCmd)
|
||||
collectWebsiteCmd.PersistentFlags().String("output", "", "Output file for the DataNode")
|
||||
|
|
@ -91,3 +85,6 @@ func init() {
|
|||
collectWebsiteCmd.PersistentFlags().String("format", "datanode", "Output format (datanode or matrix)")
|
||||
collectWebsiteCmd.PersistentFlags().String("compression", "none", "Compression format (none, gz, or xz)")
|
||||
}
|
||||
func NewCollectWebsiteCmd() *cobra.Command {
|
||||
return collectWebsiteCmd
|
||||
}
|
||||
|
cmd/collect_website_test.go (70 changed lines, new file)
@@ -0,0 +1,70 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Snider/Borg/pkg/datanode"
|
||||
"github.com/Snider/Borg/pkg/website"
|
||||
"github.com/schollz/progressbar/v3"
|
||||
)
|
||||
|
||||
func TestCollectWebsiteCmd_NoArgs(t *testing.T) {
|
||||
rootCmd := NewRootCmd()
|
||||
collectCmd := NewCollectCmd()
|
||||
collectWebsiteCmd := NewCollectWebsiteCmd()
|
||||
collectCmd.AddCommand(collectWebsiteCmd)
|
||||
rootCmd.AddCommand(collectCmd)
|
||||
_, err := executeCommand(rootCmd, "collect", "website")
|
||||
if err == nil {
|
||||
t.Fatalf("expected an error, but got none")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "accepts 1 arg(s), received 0") {
|
||||
t.Fatalf("unexpected error message: %v", err)
|
||||
}
|
||||
}
|
||||
func Test_NewCollectWebsiteCmd(t *testing.T) {
|
||||
if NewCollectWebsiteCmd() == nil {
|
||||
t.Errorf("NewCollectWebsiteCmd is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectWebsiteCmd_Good(t *testing.T) {
|
||||
oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
|
||||
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
|
||||
return datanode.New(), nil
|
||||
}
|
||||
defer func() {
|
||||
website.DownloadAndPackageWebsite = oldDownloadAndPackageWebsite
|
||||
}()
|
||||
|
||||
rootCmd := NewRootCmd()
|
||||
rootCmd.AddCommand(collectCmd)
|
||||
|
||||
out := filepath.Join(t.TempDir(), "out")
|
||||
_, err := executeCommand(rootCmd, "collect", "website", "https://example.com", "--output", out)
|
||||
if err != nil {
|
||||
t.Fatalf("collect website command failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectWebsiteCmd_Bad(t *testing.T) {
|
||||
oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
|
||||
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
|
||||
return nil, fmt.Errorf("website error")
|
||||
}
|
||||
defer func() {
|
||||
website.DownloadAndPackageWebsite = oldDownloadAndPackageWebsite
|
||||
}()
|
||||
|
||||
rootCmd := NewRootCmd()
|
||||
rootCmd.AddCommand(collectCmd)
|
||||
|
||||
out := filepath.Join(t.TempDir(), "out")
|
||||
_, err := executeCommand(rootCmd, "collect", "website", "https://example.com", "--output", out)
|
||||
if err == nil {
|
||||
t.Fatalf("expected an error, but got none")
|
||||
}
|
||||
}
|
||||
|
|
@@ -1,18 +0,0 @@
package cmd

import (
	"log/slog"
	"os"
	"testing"
)

func TestExecute(t *testing.T) {
	// This test simply checks that the Execute function can be called without error.
	// It doesn't actually test any of the application's functionality.
	log := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
		Level: slog.LevelDebug,
	}))
	if err := Execute(log); err != nil {
		t.Errorf("Execute() failed: %v", err)
	}
}
@@ -7,7 +7,6 @@ import (
	"github.com/spf13/cobra"
)

// NewRootCmd constructs the root cobra.Command for the Borg CLI and wires common flags.
func NewRootCmd() *cobra.Command {
	rootCmd := &cobra.Command{
		Use: "borg",
cmd/root_test.go (64 changed lines, new file)
@@ -0,0 +1,64 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"log/slog"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func TestExecute(t *testing.T) {
|
||||
err := Execute(slog.New(slog.NewTextHandler(io.Discard, nil)))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func executeCommand(cmd *cobra.Command, args ...string) (string, error) {
|
||||
buf := new(bytes.Buffer)
|
||||
cmd.SetOut(buf)
|
||||
cmd.SetErr(buf)
|
||||
cmd.SetArgs(args)
|
||||
|
||||
err := cmd.Execute()
|
||||
return buf.String(), err
|
||||
}
|
||||
|
||||
func Test_NewRootCmd(t *testing.T) {
|
||||
if NewRootCmd() == nil {
|
||||
t.Errorf("NewRootCmd is nil")
|
||||
}
|
||||
}
|
||||
func Test_executeCommand(t *testing.T) {
|
||||
type args struct {
|
||||
cmd *cobra.Command
|
||||
args []string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Test with no args",
|
||||
args: args{
|
||||
cmd: NewRootCmd(),
|
||||
args: []string{},
|
||||
},
|
||||
want: "",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := executeCommand(tt.args.cmd, tt.args.args...)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("executeCommand() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@@ -62,7 +62,6 @@ var serveCmd = &cobra.Command{
	},
}

// init registers the 'serve' command and its flags with the root command.
func init() {
	RootCmd.AddCommand(serveCmd)
	serveCmd.PersistentFlags().String("port", "8080", "Port to serve the PWA on")
main.go (8 changed lines)
@@ -7,12 +7,16 @@ import (
	"github.com/Snider/Borg/pkg/logger"
)

// main is the entry point for the Borg CLI application.
var osExit = os.Exit

func main() {
	Main()
}
func Main() {
	verbose, _ := cmd.RootCmd.PersistentFlags().GetBool("verbose")
	log := logger.New(verbose)
	if err := cmd.Execute(log); err != nil {
		log.Error("fatal error", "err", err)
		os.Exit(1)
		osExit(1)
	}
}
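main.go routes the fatal-error path through the osExit variable instead of calling os.Exit directly, so a test can observe the exit code without terminating the test binary. A small sketch of how a test might exercise that seam (illustrative only; no such test is added in this PR):

func TestOsExitSeam(t *testing.T) {
	oldExit := osExit
	var gotCode int
	osExit = func(code int) { gotCode = code } // record instead of terminating
	defer func() { osExit = oldExit }()

	osExit(1) // any code path that reaches osExit is now observable
	if gotCode != 1 {
		t.Fatalf("expected recorded exit code 1, got %d", gotCode)
	}
}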
pkg/compress/compress_test.go (59 changed lines, new file)
@@ -0,0 +1,59 @@
|
|||
package compress
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCompressDecompress(t *testing.T) {
|
||||
testData := []byte("hello, world")
|
||||
|
||||
// Test gzip compression
|
||||
compressedGz, err := Compress(testData, "gz")
|
||||
if err != nil {
|
||||
t.Fatalf("gzip compression failed: %v", err)
|
||||
}
|
||||
|
||||
decompressedGz, err := Decompress(compressedGz)
|
||||
if err != nil {
|
||||
t.Fatalf("gzip decompression failed: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(testData, decompressedGz) {
|
||||
t.Errorf("gzip decompressed data does not match original data")
|
||||
}
|
||||
|
||||
// Test xz compression
|
||||
compressedXz, err := Compress(testData, "xz")
|
||||
if err != nil {
|
||||
t.Fatalf("xz compression failed: %v", err)
|
||||
}
|
||||
|
||||
decompressedXz, err := Decompress(compressedXz)
|
||||
if err != nil {
|
||||
t.Fatalf("xz decompression failed: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(testData, decompressedXz) {
|
||||
t.Errorf("xz decompressed data does not match original data")
|
||||
}
|
||||
|
||||
// Test no compression
|
||||
compressedNone, err := Compress(testData, "none")
|
||||
if err != nil {
|
||||
t.Fatalf("no compression failed: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(testData, compressedNone) {
|
||||
t.Errorf("no compression data does not match original data")
|
||||
}
|
||||
|
||||
decompressedNone, err := Decompress(compressedNone)
|
||||
if err != nil {
|
||||
t.Fatalf("no compression decompression failed: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(testData, decompressedNone) {
|
||||
t.Errorf("no compression decompressed data does not match original data")
|
||||
}
|
||||
}
|
||||
|
|
@ -260,35 +260,19 @@ type dataFile struct {
|
|||
modTime time.Time
|
||||
}
|
||||
|
||||
// Stat implements fs.File.Stat for a dataFile and returns its FileInfo.
|
||||
func (d *dataFile) Stat() (fs.FileInfo, error) { return &dataFileInfo{file: d}, nil }
|
||||
|
||||
// Read implements fs.File.Read for a dataFile and reports EOF as files are in-memory.
|
||||
func (d *dataFile) Read(p []byte) (int, error) { return 0, io.EOF }
|
||||
|
||||
// Close implements fs.File.Close for a dataFile; it's a no-op.
|
||||
func (d *dataFile) Close() error { return nil }
|
||||
func (d *dataFile) Close() error { return nil }
|
||||
|
||||
// dataFileInfo implements fs.FileInfo for a dataFile.
|
||||
type dataFileInfo struct{ file *dataFile }
|
||||
|
||||
// Name returns the base name of the data file.
|
||||
func (d *dataFileInfo) Name() string { return path.Base(d.file.name) }
|
||||
|
||||
// Size returns the size of the data file in bytes.
|
||||
func (d *dataFileInfo) Size() int64 { return int64(len(d.file.content)) }
|
||||
|
||||
// Mode returns the file mode for data files (read-only).
|
||||
func (d *dataFileInfo) Mode() fs.FileMode { return 0444 }
|
||||
|
||||
// ModTime returns the modification time of the data file.
|
||||
func (d *dataFileInfo) Name() string { return path.Base(d.file.name) }
|
||||
func (d *dataFileInfo) Size() int64 { return int64(len(d.file.content)) }
|
||||
func (d *dataFileInfo) Mode() fs.FileMode { return 0444 }
|
||||
func (d *dataFileInfo) ModTime() time.Time { return d.file.modTime }
|
||||
|
||||
// IsDir reports whether the entry is a directory (false for data files).
|
||||
func (d *dataFileInfo) IsDir() bool { return false }
|
||||
|
||||
// Sys returns system-specific data, which is nil for data files.
|
||||
func (d *dataFileInfo) Sys() interface{} { return nil }
|
||||
func (d *dataFileInfo) IsDir() bool { return false }
|
||||
func (d *dataFileInfo) Sys() interface{} { return nil }
|
||||
|
||||
// dataFileReader implements fs.File for a dataFile.
|
||||
type dataFileReader struct {
|
||||
|
|
@ -296,18 +280,13 @@ type dataFileReader struct {
|
|||
reader *bytes.Reader
|
||||
}
|
||||
|
||||
// Stat returns the FileInfo for the underlying data file.
|
||||
func (d *dataFileReader) Stat() (fs.FileInfo, error) { return d.file.Stat() }
|
||||
|
||||
// Read reads from the underlying data file content.
|
||||
func (d *dataFileReader) Read(p []byte) (int, error) {
|
||||
if d.reader == nil {
|
||||
d.reader = bytes.NewReader(d.file.content)
|
||||
}
|
||||
return d.reader.Read(p)
|
||||
}
|
||||
|
||||
// Close closes the data file reader (no-op).
|
||||
func (d *dataFileReader) Close() error { return nil }
|
||||
|
||||
// dirInfo implements fs.FileInfo for an implicit directory.
|
||||
|
|
@ -316,23 +295,12 @@ type dirInfo struct {
|
|||
modTime time.Time
|
||||
}
|
||||
|
||||
// Name returns the directory name.
|
||||
func (d *dirInfo) Name() string { return d.name }
|
||||
|
||||
// Size returns the size of the directory entry (always 0 for virtual dirs).
|
||||
func (d *dirInfo) Size() int64 { return 0 }
|
||||
|
||||
// Mode returns the file mode for directories.
|
||||
func (d *dirInfo) Mode() fs.FileMode { return fs.ModeDir | 0555 }
|
||||
|
||||
// ModTime returns the modification time of the directory.
|
||||
func (d *dirInfo) Name() string { return d.name }
|
||||
func (d *dirInfo) Size() int64 { return 0 }
|
||||
func (d *dirInfo) Mode() fs.FileMode { return fs.ModeDir | 0555 }
|
||||
func (d *dirInfo) ModTime() time.Time { return d.modTime }
|
||||
|
||||
// IsDir reports that the entry is a directory.
|
||||
func (d *dirInfo) IsDir() bool { return true }
|
||||
|
||||
// Sys returns system-specific data, which is nil for dirs.
|
||||
func (d *dirInfo) Sys() interface{} { return nil }
|
||||
func (d *dirInfo) IsDir() bool { return true }
|
||||
func (d *dirInfo) Sys() interface{} { return nil }
|
||||
|
||||
// dirFile implements fs.File for a directory.
|
||||
type dirFile struct {
|
||||
|
|
@ -340,15 +308,10 @@ type dirFile struct {
|
|||
modTime time.Time
|
||||
}
|
||||
|
||||
// Stat returns the FileInfo for the directory.
|
||||
func (d *dirFile) Stat() (fs.FileInfo, error) {
|
||||
return &dirInfo{name: path.Base(d.path), modTime: d.modTime}, nil
|
||||
}
|
||||
|
||||
// Read is invalid for directories and returns an error.
|
||||
func (d *dirFile) Read([]byte) (int, error) {
|
||||
return 0, &fs.PathError{Op: "read", Path: d.path, Err: fs.ErrInvalid}
|
||||
}
|
||||
|
||||
// Close closes the directory file (no-op).
|
||||
func (d *dirFile) Close() error { return nil }
|
||||
|
|
|
|||
|
|
@@ -27,13 +27,8 @@ func NewGithubClient() GithubClient {

type githubClient struct{}

func (g *githubClient) GetPublicRepos(ctx context.Context, userOrOrg string) ([]string, error) {
	return g.getPublicReposWithAPIURL(ctx, "https://api.github.com", userOrOrg)
}

// newAuthenticatedClient returns an *http.Client that uses GITHUB_TOKEN if set.
// When no token is present, it falls back to http.DefaultClient.
func (g *githubClient) newAuthenticatedClient(ctx context.Context) *http.Client {
// NewAuthenticatedClient creates a new authenticated http client.
var NewAuthenticatedClient = func(ctx context.Context) *http.Client {
	token := os.Getenv("GITHUB_TOKEN")
	if token == "" {
		return http.DefaultClient
@@ -44,10 +39,12 @@ func (g *githubClient) newAuthenticatedClient(ctx context.Context) *http.Client
	return oauth2.NewClient(ctx, ts)
}

// getPublicReposWithAPIURL fetches public repository clone URLs for the given user/org
// using the provided GitHub API base URL, following pagination links as needed.
func (g *githubClient) GetPublicRepos(ctx context.Context, userOrOrg string) ([]string, error) {
	return g.getPublicReposWithAPIURL(ctx, "https://api.github.com", userOrOrg)
}

func (g *githubClient) getPublicReposWithAPIURL(ctx context.Context, apiURL, userOrOrg string) ([]string, error) {
	client := g.newAuthenticatedClient(ctx)
	client := NewAuthenticatedClient(ctx)
	var allCloneURLs []string
	url := fmt.Sprintf("%s/users/%s/repos", apiURL, userOrOrg)

@@ -110,7 +107,6 @@ func (g *githubClient) getPublicReposWithAPIURL(ctx context.Context, apiURL, userOrOrg
	return allCloneURLs, nil
}

// findNextURL parses the HTTP Link header and returns the URL with rel="next", if any.
func (g *githubClient) findNextURL(linkHeader string) string {
	links := strings.Split(linkHeader, ",")
	for _, link := range links {
|||
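For orientation, a hedged sketch of how a caller drives this client: construct it through the exported NewGithubClient constructor named in the hunk header, let NewAuthenticatedClient pick up GITHUB_TOKEN from the environment when it is set, and iterate the returned clone URLs. The owner name is a placeholder:

package main

import (
    "context"
    "fmt"
    "log"

    "github.com/Snider/Borg/pkg/github"
)

func main() {
    // NewAuthenticatedClient uses GITHUB_TOKEN when present,
    // otherwise requests go out through http.DefaultClient.
    client := github.NewGithubClient()
    cloneURLs, err := client.GetPublicRepos(context.Background(), "example-org")
    if err != nil {
        log.Fatalf("listing repos failed: %v", err)
    }
    for _, u := range cloneURLs {
        fmt.Println(u) // e.g. https://github.com/example-org/some-repo.git
    }
}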
pkg/github/github_test.go (new file, 124 lines)

@@ -0,0 +1,124 @@
package github

import (
    "bytes"
    "context"
    "io"
    "net/http"
    "net/url"
    "testing"

    "github.com/Snider/Borg/pkg/mocks"
)

func TestGetPublicRepos(t *testing.T) {
    mockClient := mocks.NewMockClient(map[string]*http.Response{
        "https://api.github.com/users/testuser/repos": {
            StatusCode: http.StatusOK,
            Header:     http.Header{"Content-Type": []string{"application/json"}},
            Body:       io.NopCloser(bytes.NewBufferString(`[{"clone_url": "https://github.com/testuser/repo1.git"}]`)),
        },
        "https://api.github.com/orgs/testorg/repos": {
            StatusCode: http.StatusOK,
            Header:     http.Header{"Content-Type": []string{"application/json"}, "Link": []string{`<https://api.github.com/organizations/123/repos?page=2>; rel="next"`}},
            Body:       io.NopCloser(bytes.NewBufferString(`[{"clone_url": "https://github.com/testorg/repo1.git"}]`)),
        },
        "https://api.github.com/organizations/123/repos?page=2": {
            StatusCode: http.StatusOK,
            Header:     http.Header{"Content-Type": []string{"application/json"}},
            Body:       io.NopCloser(bytes.NewBufferString(`[{"clone_url": "https://github.com/testorg/repo2.git"}]`)),
        },
    })

    client := &githubClient{}
    oldClient := NewAuthenticatedClient
    NewAuthenticatedClient = func(ctx context.Context) *http.Client {
        return mockClient
    }
    defer func() {
        NewAuthenticatedClient = oldClient
    }()

    // Test user repos
    repos, err := client.getPublicReposWithAPIURL(context.Background(), "https://api.github.com", "testuser")
    if err != nil {
        t.Fatalf("getPublicReposWithAPIURL for user failed: %v", err)
    }
    if len(repos) != 1 || repos[0] != "https://github.com/testuser/repo1.git" {
        t.Errorf("unexpected user repos: %v", repos)
    }

    // Test org repos with pagination
    repos, err = client.getPublicReposWithAPIURL(context.Background(), "https://api.github.com", "testorg")
    if err != nil {
        t.Fatalf("getPublicReposWithAPIURL for org failed: %v", err)
    }
    if len(repos) != 2 || repos[0] != "https://github.com/testorg/repo1.git" || repos[1] != "https://github.com/testorg/repo2.git" {
        t.Errorf("unexpected org repos: %v", repos)
    }
}

func TestGetPublicRepos_Error(t *testing.T) {
    u, _ := url.Parse("https://api.github.com/users/testuser/repos")
    mockClient := mocks.NewMockClient(map[string]*http.Response{
        "https://api.github.com/users/testuser/repos": {
            StatusCode: http.StatusNotFound,
            Status:     "404 Not Found",
            Header:     http.Header{"Content-Type": []string{"application/json"}},
            Body:       io.NopCloser(bytes.NewBufferString("")),
            Request:    &http.Request{Method: "GET", URL: u},
        },
        "https://api.github.com/orgs/testuser/repos": {
            StatusCode: http.StatusNotFound,
            Status:     "404 Not Found",
            Header:     http.Header{"Content-Type": []string{"application/json"}},
            Body:       io.NopCloser(bytes.NewBufferString("")),
            Request:    &http.Request{Method: "GET", URL: u},
        },
    })
    expectedErr := "failed to fetch repos: 404 Not Found"

    client := &githubClient{}
    oldClient := NewAuthenticatedClient
    NewAuthenticatedClient = func(ctx context.Context) *http.Client {
        return mockClient
    }
    defer func() {
        NewAuthenticatedClient = oldClient
    }()

    // Test user repos
    _, err := client.getPublicReposWithAPIURL(context.Background(), "https://api.github.com", "testuser")
    if err.Error() != expectedErr {
        t.Fatalf("getPublicReposWithAPIURL for user failed: %v", err)
    }
}

func TestFindNextURL(t *testing.T) {
    client := &githubClient{}
    linkHeader := `<https://api.github.com/organizations/123/repos?page=2>; rel="next", <https://api.github.com/organizations/123/repos?page=1>; rel="prev"`
    nextURL := client.findNextURL(linkHeader)
    if nextURL != "https://api.github.com/organizations/123/repos?page=2" {
        t.Errorf("unexpected next URL: %s", nextURL)
    }

    linkHeader = `<https://api.github.com/organizations/123/repos?page=1>; rel="prev"`
    nextURL = client.findNextURL(linkHeader)
    if nextURL != "" {
        t.Errorf("unexpected next URL: %s", nextURL)
    }
}

func TestNewAuthenticatedClient(t *testing.T) {
    // Test with no token
    client := NewAuthenticatedClient(context.Background())
    if client != http.DefaultClient {
        t.Errorf("expected http.DefaultClient, but got something else")
    }

    // Test with token
    t.Setenv("GITHUB_TOKEN", "test-token")
    client = NewAuthenticatedClient(context.Background())
    if client == http.DefaultClient {
        t.Errorf("expected an authenticated client, but got http.DefaultClient")
    }
}
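These tests lean on mocks.NewMockClient, whose implementation is not part of this hunk. From the call sites it is evidently an *http.Client whose transport replays a canned *http.Response per request URL. The following is only a plausible minimal version offered as a reading aid, not the project's actual code:

package mocks

import (
    "fmt"
    "net/http"
)

// mapTransport serves pre-registered responses keyed by the full request URL.
type mapTransport struct {
    responses map[string]*http.Response
}

func (t *mapTransport) RoundTrip(req *http.Request) (*http.Response, error) {
    if resp, ok := t.responses[req.URL.String()]; ok {
        if resp.Request == nil {
            resp.Request = req
        }
        return resp, nil
    }
    return nil, fmt.Errorf("no mock response registered for %s", req.URL.String())
}

// NewMockClient returns an *http.Client that answers from the supplied map.
func NewMockClient(responses map[string]*http.Response) *http.Client {
    return &http.Client{Transport: &mapTransport{responses: responses}}
}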
@@ -1,19 +1,33 @@
package github

import (
    "bytes"
    "context"
    "fmt"
    "io"
    "net/http"
    "os"
    "strings"

    "github.com/google/go-github/v39/github"
)

var (
    // NewClient is a variable that holds the function to create a new GitHub client.
    // This allows for mocking in tests.
    NewClient = func(httpClient *http.Client) *github.Client {
        return github.NewClient(httpClient)
    }
    // NewRequest is a variable that holds the function to create a new HTTP request.
    NewRequest = func(method, url string, body io.Reader) (*http.Request, error) {
        return http.NewRequest(method, url, body)
    }
    // DefaultClient is the default http client
    DefaultClient = &http.Client{}
)

// GetLatestRelease gets the latest release for a repository.
func GetLatestRelease(owner, repo string) (*github.RepositoryRelease, error) {
    client := github.NewClient(nil)
    client := NewClient(nil)
    release, _, err := client.Repositories.GetLatestRelease(context.Background(), owner, repo)
    if err != nil {
        return nil, err

@@ -22,32 +36,29 @@ func GetLatestRelease(owner, repo string) (*github.RepositoryRelease, error) {
}

// DownloadReleaseAsset downloads a release asset.
func DownloadReleaseAsset(asset *github.ReleaseAsset, path string) error {
    client := &http.Client{}
    req, err := http.NewRequest("GET", asset.GetBrowserDownloadURL(), nil)
func DownloadReleaseAsset(asset *github.ReleaseAsset) ([]byte, error) {
    req, err := NewRequest("GET", asset.GetBrowserDownloadURL(), nil)
    if err != nil {
        return err
        return nil, err
    }
    req.Header.Set("Accept", "application/octet-stream")

    resp, err := client.Do(req)
    resp, err := DefaultClient.Do(req)
    if err != nil {
        return err
        return nil, err
    }
    defer resp.Body.Close()

    if resp.StatusCode != http.StatusOK {
        return fmt.Errorf("bad status: %s", resp.Status)
        return nil, fmt.Errorf("bad status: %s", resp.Status)
    }

    f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0644)
    buf := new(bytes.Buffer)
    _, err = io.Copy(buf, resp.Body)
    if err != nil {
        return err
        return nil, err
    }
    defer f.Close()

    _, err = io.Copy(f, resp.Body)
    return err
    return buf.Bytes(), nil
}

// ParseRepoFromURL parses the owner and repository from a GitHub URL.
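The signature change above means DownloadReleaseAsset no longer writes to disk itself; it returns the asset bytes and leaves persistence to the caller. A sketch of the new calling pattern; the owner, repo and asset choice are placeholders:

package main

import (
    "log"
    "os"

    "github.com/Snider/Borg/pkg/github"
)

func main() {
    release, err := github.GetLatestRelease("owner", "repo")
    if err != nil {
        log.Fatal(err)
    }

    assets := release.Assets
    if len(assets) == 0 {
        log.Fatal("release has no assets")
    }

    // DownloadReleaseAsset now returns the raw bytes.
    data, err := github.DownloadReleaseAsset(assets[0])
    if err != nil {
        log.Fatal(err)
    }

    // Persisting the bytes is the caller's job.
    if err := os.WriteFile(assets[0].GetName(), data, 0o644); err != nil {
        log.Fatal(err)
    }
}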
pkg/github/release_test.go (new file, 198 lines)

@@ -0,0 +1,198 @@
package github

import (
    "bytes"
    "fmt"
    "io"
    "net/http"
    "net/url"
    "testing"

    "github.com/Snider/Borg/pkg/mocks"
    "github.com/google/go-github/v39/github"
)

type errorRoundTripper struct{}

func (e *errorRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
    return nil, fmt.Errorf("do error")
}

func TestParseRepoFromURL(t *testing.T) {
    testCases := []struct {
        url       string
        owner     string
        repo      string
        expectErr bool
    }{
        {"https://github.com/owner/repo.git", "owner", "repo", false},
        {"http://github.com/owner/repo", "owner", "repo", false},
        {"git://github.com/owner/repo.git", "owner", "repo", false},
        {"github.com/owner/repo", "owner", "repo", false},
        {"git@github.com:owner/repo.git", "owner", "repo", false},
        {"https://github.com/owner/repo/tree/main", "", "", true},
        {"invalid-url", "", "", true},
        {"git:github.com:owner/repo.git", "owner", "repo", false},
    }

    for _, tc := range testCases {
        owner, repo, err := ParseRepoFromURL(tc.url)
        if (err != nil) != tc.expectErr {
            t.Errorf("unexpected error for URL %s: %v", tc.url, err)
        }
        if owner != tc.owner || repo != tc.repo {
            t.Errorf("unexpected owner/repo for URL %s: %s/%s", tc.url, owner, repo)
        }
    }
}

func TestGetLatestRelease(t *testing.T) {
    oldNewClient := NewClient
    t.Cleanup(func() { NewClient = oldNewClient })

    mockClient := mocks.NewMockClient(map[string]*http.Response{
        "https://api.github.com/repos/owner/repo/releases/latest": {
            StatusCode: http.StatusOK,
            Body:       io.NopCloser(bytes.NewBufferString(`{"tag_name": "v1.0.0"}`)),
        },
    })
    client := github.NewClient(mockClient)
    NewClient = func(_ *http.Client) *github.Client {
        return client
    }

    release, err := GetLatestRelease("owner", "repo")
    if err != nil {
        t.Fatalf("GetLatestRelease failed: %v", err)
    }

    if release.GetTagName() != "v1.0.0" {
        t.Errorf("unexpected tag name: %s", release.GetTagName())
    }
}

func TestDownloadReleaseAsset(t *testing.T) {
    mockClient := mocks.NewMockClient(map[string]*http.Response{
        "https://github.com/owner/repo/releases/download/v1.0.0/asset.zip": {
            StatusCode: http.StatusOK,
            Body:       io.NopCloser(bytes.NewBufferString("asset content")),
        },
    })

    asset := &github.ReleaseAsset{
        BrowserDownloadURL: github.String("https://github.com/owner/repo/releases/download/v1.0.0/asset.zip"),
    }

    oldClient := DefaultClient
    DefaultClient = mockClient
    defer func() {
        DefaultClient = oldClient
    }()

    data, err := DownloadReleaseAsset(asset)
    if err != nil {
        t.Fatalf("DownloadReleaseAsset failed: %v", err)
    }

    if string(data) != "asset content" {
        t.Errorf("unexpected asset content: %s", string(data))
    }
}

func TestDownloadReleaseAsset_BadRequest(t *testing.T) {
    mockClient := mocks.NewMockClient(map[string]*http.Response{
        "https://github.com/owner/repo/releases/download/v1.0.0/asset.zip": {
            StatusCode: http.StatusBadRequest,
            Status:     "400 Bad Request",
            Body:       io.NopCloser(bytes.NewBufferString("")),
        },
    })
    expectedErr := "bad status: 400 Bad Request"

    asset := &github.ReleaseAsset{
        BrowserDownloadURL: github.String("https://github.com/owner/repo/releases/download/v1.0.0/asset.zip"),
    }

    oldClient := DefaultClient
    DefaultClient = mockClient
    defer func() {
        DefaultClient = oldClient
    }()

    _, err := DownloadReleaseAsset(asset)
    if err == nil {
        t.Fatalf("expected error but got nil")
    }
    if err.Error() != expectedErr {
        t.Fatalf("DownloadReleaseAsset failed: %v", err)
    }
}

func TestDownloadReleaseAsset_NewRequestError(t *testing.T) {
    errRequest := fmt.Errorf("bad request")
    asset := &github.ReleaseAsset{
        BrowserDownloadURL: github.String("https://github.com/owner/repo/releases/download/v1.0.0/asset.zip"),
    }

    oldNewRequest := NewRequest
    NewRequest = func(method, url string, body io.Reader) (*http.Request, error) {
        return nil, errRequest
    }
    defer func() {
        NewRequest = oldNewRequest
    }()

    _, err := DownloadReleaseAsset(asset)
    if err == nil {
        t.Fatalf("expected error but got nil")
    }
}

func TestGetLatestRelease_Error(t *testing.T) {
    oldNewClient := NewClient
    t.Cleanup(func() { NewClient = oldNewClient })

    u, _ := url.Parse("https://api.github.com/repos/owner/repo/releases/latest")
    mockClient := mocks.NewMockClient(map[string]*http.Response{
        "https://api.github.com/repos/owner/repo/releases/latest": {
            StatusCode: http.StatusNotFound,
            Status:     "404 Not Found",
            Header:     http.Header{"Content-Type": []string{"application/json"}},
            Body:       io.NopCloser(bytes.NewBufferString("")),
            Request:    &http.Request{Method: "GET", URL: u},
        },
    })
    expectedErr := "GET https://api.github.com/repos/owner/repo/releases/latest: 404 []"
    client := github.NewClient(mockClient)
    NewClient = func(_ *http.Client) *github.Client {
        return client
    }

    _, err := GetLatestRelease("owner", "repo")
    if err == nil {
        t.Fatalf("expected error but got nil")
    }
    if err.Error() != expectedErr {
        t.Fatalf("GetLatestRelease failed: %v", err)
    }
}

func TestDownloadReleaseAsset_DoError(t *testing.T) {
    mockClient := &http.Client{
        Transport: &errorRoundTripper{},
    }

    asset := &github.ReleaseAsset{
        BrowserDownloadURL: github.String("https://github.com/owner/repo/releases/download/v1.0.0/asset.zip"),
    }

    oldClient := DefaultClient
    DefaultClient = mockClient
    defer func() {
        DefaultClient = oldClient
    }()

    _, err := DownloadReleaseAsset(asset)
    if err == nil {
        t.Fatalf("DownloadReleaseAsset should have failed")
    }
}
@@ -5,8 +5,6 @@ import (
    "os"
)

// New constructs a slog.Logger configured for stderr with the given verbosity.
// When verbose is true, the logger emits debug-level messages; otherwise info-level.
func New(verbose bool) *slog.Logger {
    level := slog.LevelInfo
    if verbose {
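Usage stays a one-liner; the verbose flag only moves the handler's level between Info and Debug. A small illustration:

package main

import "github.com/Snider/Borg/pkg/logger"

func main() {
    log := logger.New(true) // verbose: debug-level messages reach stderr
    log.Debug("resolving collector", "url", "https://example.org")

    quiet := logger.New(false) // info-level only
    quiet.Debug("this line is filtered out")
    quiet.Info("this one is printed")
}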
pkg/logger/logger_test.go (new file, 64 lines)

@@ -0,0 +1,64 @@
package logger

import (
    "bytes"
    "context"
    "io"
    "os"
    "log/slog"
    "testing"
)

func TestNew(t *testing.T) {
    // Test non-verbose logger
    var buf bytes.Buffer
    oldStderr := os.Stderr
    r, w, _ := os.Pipe()
    os.Stderr = w

    log := New(false)
    log.Info("info message")
    log.Debug("debug message")

    w.Close()
    os.Stderr = oldStderr
    io.Copy(&buf, r)

    if !bytes.Contains(buf.Bytes(), []byte("info message")) {
        t.Errorf("expected info message to be logged")
    }
    if bytes.Contains(buf.Bytes(), []byte("debug message")) {
        t.Errorf("expected debug message not to be logged")
    }

    // Test verbose logger
    buf.Reset()
    r, w, _ = os.Pipe()
    os.Stderr = w

    log = New(true)
    log.Info("info message")
    log.Debug("debug message")

    w.Close()
    os.Stderr = oldStderr
    io.Copy(&buf, r)

    if !bytes.Contains(buf.Bytes(), []byte("info message")) {
        t.Errorf("expected info message to be logged")
    }
    if !bytes.Contains(buf.Bytes(), []byte("debug message")) {
        t.Errorf("expected debug message to be logged")
    }
}

func TestNew_Level(t *testing.T) {
    log := New(true)
    if !log.Enabled(context.Background(), slog.LevelDebug) {
        t.Errorf("expected debug level to be enabled")
    }

    log = New(false)
    if log.Enabled(context.Background(), slog.LevelDebug) {
        t.Errorf("expected debug level to be disabled")
    }
}
pkg/mocks/mock_vcs.go (new file, 27 lines)

@@ -0,0 +1,27 @@
package mocks

import (
    "io"

    "github.com/Snider/Borg/pkg/datanode"
    "github.com/Snider/Borg/pkg/vcs"
)

// MockGitCloner is a mock implementation of the GitCloner interface.
type MockGitCloner struct {
    DN  *datanode.DataNode
    Err error
}

// NewMockGitCloner creates a new MockGitCloner.
func NewMockGitCloner(dn *datanode.DataNode, err error) vcs.GitCloner {
    return &MockGitCloner{
        DN:  dn,
        Err: err,
    }
}

// CloneGitRepository mocks the cloning of a Git repository.
func (m *MockGitCloner) CloneGitRepository(repoURL string, progress io.Writer) (*datanode.DataNode, error) {
    return m.DN, m.Err
}
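The mock exists so code that depends on vcs.GitCloner can be tested without running git or touching the network. A short sketch of how a test constructs it and exercises both the success and error paths; the repository URL is a placeholder:

package example_test

import (
    "errors"
    "testing"

    "github.com/Snider/Borg/pkg/datanode"
    "github.com/Snider/Borg/pkg/mocks"
    "github.com/Snider/Borg/pkg/vcs"
)

func TestWithMockCloner(t *testing.T) {
    // A canned result instead of a real `git clone`.
    var cloner vcs.GitCloner = mocks.NewMockGitCloner(datanode.New(), nil)

    dn, err := cloner.CloneGitRepository("https://github.com/example/repo.git", nil)
    if err != nil || dn == nil {
        t.Fatalf("mock should return the canned DataNode, got %v, %v", dn, err)
    }

    // Error paths are exercised the same way.
    failing := mocks.NewMockGitCloner(nil, errors.New("network down"))
    if _, err := failing.CloneGitRepository("https://github.com/example/repo.git", nil); err == nil {
        t.Fatal("expected the canned error")
    }
}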
pkg/pwa/pwa.go (216 lines)

@@ -6,49 +6,31 @@ import (
    "io"
    "net/http"
    "net/url"
    "path"
    "strings"

    "github.com/Snider/Borg/pkg/datanode"
    "github.com/schollz/progressbar/v3"

    "golang.org/x/net/html"
)

// Manifest represents a simple PWA manifest structure.
type Manifest struct {
    Name      string `json:"name"`
    ShortName string `json:"short_name"`
    StartURL  string `json:"start_url"`
    Icons     []Icon `json:"icons"`
}

// Icon represents an icon in the PWA manifest.
type Icon struct {
    Src   string `json:"src"`
    Sizes string `json:"sizes"`
    Type  string `json:"type"`
}

// PWAClient is an interface for finding and downloading PWAs.
// PWAClient is an interface for interacting with PWAs.
type PWAClient interface {
    FindManifest(pageURL string) (string, error)
    DownloadAndPackagePWA(baseURL string, manifestURL string, bar *progressbar.ProgressBar) (*datanode.DataNode, error)
    FindManifest(pwaURL string) (string, error)
    DownloadAndPackagePWA(pwaURL, manifestURL string, bar *progressbar.ProgressBar) (*datanode.DataNode, error)
}

// NewPWAClient creates a new PWAClient.
func NewPWAClient() PWAClient {
    return &pwaClient{
        client: http.DefaultClient,
    }
    return &pwaClient{client: http.DefaultClient}
}

type pwaClient struct {
    client *http.Client
}

// FindManifest finds the manifest URL from a given HTML page.
func (p *pwaClient) FindManifest(pageURL string) (string, error) {
    resp, err := p.client.Get(pageURL)
// FindManifest finds the manifest for a PWA.
func (p *pwaClient) FindManifest(pwaURL string) (string, error) {
    resp, err := p.client.Get(pwaURL)
    if err != nil {
        return "", err
    }

@@ -59,106 +41,135 @@ func (p *pwaClient) FindManifest(pageURL string) (string, error) {
        return "", err
    }

    var manifestPath string
    var manifestURL string
    var f func(*html.Node)
    f = func(n *html.Node) {
        if n.Type == html.ElementNode && n.Data == "link" {
            isManifest := false
            var isManifest bool
            var href string
            for _, a := range n.Attr {
                if a.Key == "rel" && a.Val == "manifest" {
                    isManifest = true
                    break
                }
                if a.Key == "href" {
                    href = a.Val
                }
            }
            if isManifest {
                for _, a := range n.Attr {
                    if a.Key == "href" {
                        manifestPath = a.Val
                        return // exit once found
                    }
                }
            if isManifest && href != "" {
                manifestURL = href
                return
            }
        }
        for c := n.FirstChild; c != nil && manifestPath == ""; c = c.NextSibling {
        for c := n.FirstChild; c != nil && manifestURL == ""; c = c.NextSibling {
            f(c)
        }
    }
    f(doc)

    if manifestPath == "" {
    if manifestURL == "" {
        return "", fmt.Errorf("manifest not found")
    }

    resolvedURL, err := p.resolveURL(pageURL, manifestPath)
    resolvedURL, err := p.resolveURL(pwaURL, manifestURL)
    if err != nil {
        return "", fmt.Errorf("could not resolve manifest URL: %w", err)
        return "", err
    }

    return resolvedURL.String(), nil
}

// DownloadAndPackagePWA downloads all assets of a PWA and packages them into a DataNode.
func (p *pwaClient) DownloadAndPackagePWA(baseURL string, manifestURL string, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
    if bar == nil {
        return nil, fmt.Errorf("progress bar cannot be nil")
    }
    manifestAbsURL, err := p.resolveURL(baseURL, manifestURL)
    if err != nil {
        return nil, fmt.Errorf("could not resolve manifest URL: %w", err)
// DownloadAndPackagePWA downloads and packages a PWA into a DataNode.
func (p *pwaClient) DownloadAndPackagePWA(pwaURL, manifestURL string, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
    dn := datanode.New()

    type Manifest struct {
        StartURL string `json:"start_url"`
        Icons    []struct {
            Src string `json:"src"`
        } `json:"icons"`
    }

    resp, err := p.client.Get(manifestAbsURL.String())
    if err != nil {
        return nil, fmt.Errorf("could not download manifest: %w", err)
    }
    defer resp.Body.Close()
    downloadAndAdd := func(assetURL string) error {
        if bar != nil {
            bar.Add(1)
        }
        resp, err := p.client.Get(assetURL)
        if err != nil {
            return fmt.Errorf("failed to download %s: %w", assetURL, err)
        }
        defer resp.Body.Close()

    manifestBody, err := io.ReadAll(resp.Body)
        if resp.StatusCode < 200 || resp.StatusCode >= 300 {
            return fmt.Errorf("failed to download %s: status code %d", assetURL, resp.StatusCode)
        }

        body, err := io.ReadAll(resp.Body)
        if err != nil {
            return fmt.Errorf("failed to read body of %s: %w", assetURL, err)
        }

        u, err := url.Parse(assetURL)
        if err != nil {
            return fmt.Errorf("failed to parse asset URL %s: %w", assetURL, err)
        }
        dn.AddData(strings.TrimPrefix(u.Path, "/"), body)
        return nil
    }

    // Download manifest
    if err := downloadAndAdd(manifestURL); err != nil {
        return nil, err
    }

    // Parse manifest and download assets
    var manifestPath string
    u, parseErr := url.Parse(manifestURL)
    if parseErr != nil {
        manifestPath = "manifest.json"
    } else {
        manifestPath = strings.TrimPrefix(u.Path, "/")
    }

    manifestFile, err := dn.Open(manifestPath)
    if err != nil {
        return nil, fmt.Errorf("could not read manifest body: %w", err)
        return nil, fmt.Errorf("failed to open manifest from datanode: %w", err)
    }
    defer manifestFile.Close()

    manifestData, err := io.ReadAll(manifestFile)
    if err != nil {
        return nil, fmt.Errorf("failed to read manifest from datanode: %w", err)
    }

    var manifest Manifest
    if err := json.Unmarshal(manifestBody, &manifest); err != nil {
        return nil, fmt.Errorf("could not parse manifest JSON: %w", err)
    if err := json.Unmarshal(manifestData, &manifest); err != nil {
        return nil, fmt.Errorf("failed to parse manifest: %w", err)
    }

    dn := datanode.New()
    dn.AddData("manifest.json", manifestBody)

    if manifest.StartURL != "" {
        startURLAbs, err := p.resolveURL(manifestAbsURL.String(), manifest.StartURL)
        if err != nil {
            return nil, fmt.Errorf("could not resolve start_url: %w", err)
        }
        err = p.downloadAndAddFile(dn, startURLAbs, manifest.StartURL, bar)
        if err != nil {
            return nil, fmt.Errorf("failed to download start_url asset: %w", err)
        }
        // Download start_url
        startURL, err := p.resolveURL(manifestURL, manifest.StartURL)
        if err != nil {
            return nil, fmt.Errorf("failed to resolve start_url: %w", err)
        }
        if err := downloadAndAdd(startURL.String()); err != nil {
            return nil, err
        }

    // Download icons
    for _, icon := range manifest.Icons {
        iconURLAbs, err := p.resolveURL(manifestAbsURL.String(), icon.Src)
        iconURL, err := p.resolveURL(manifestURL, icon.Src)
        if err != nil {
            fmt.Printf("Warning: could not resolve icon URL %s: %v\n", icon.Src, err)
            // Skip icons with bad URLs
            continue
        }
        err = p.downloadAndAddFile(dn, iconURLAbs, icon.Src, bar)
        if err != nil {
            fmt.Printf("Warning: failed to download icon %s: %v\n", icon.Src, err)
        if err := downloadAndAdd(iconURL.String()); err != nil {
            return nil, err
        }
    }

    baseURLAbs, _ := url.Parse(baseURL)
    err = p.downloadAndAddFile(dn, baseURLAbs, "index.html", bar)
    if err != nil {
        return nil, fmt.Errorf("failed to download base HTML: %w", err)
    }

    return dn, nil
}

// resolveURL resolves ref against base and returns the absolute URL.
func (p *pwaClient) resolveURL(base, ref string) (*url.URL, error) {
    baseURL, err := url.Parse(base)
    if err != nil {

@@ -171,23 +182,28 @@ func (p *pwaClient) resolveURL(base, ref string) (*url.URL, error) {
    return baseURL.ResolveReference(refURL), nil
}

// downloadAndAddFile downloads fileURL and adds it to dn at internalPath, updating the progress bar.
func (p *pwaClient) downloadAndAddFile(dn *datanode.DataNode, fileURL *url.URL, internalPath string, bar *progressbar.ProgressBar) error {
    resp, err := p.client.Get(fileURL.String())
    if err != nil {
        return err
    }
    defer resp.Body.Close()

    if resp.StatusCode != http.StatusOK {
        return fmt.Errorf("bad status: %s", resp.Status)
    }

    data, err := io.ReadAll(resp.Body)
    if err != nil {
        return err
    }
    dn.AddData(path.Clean(internalPath), data)
    bar.Add(1)
    return nil
// MockPWAClient is a mock implementation of the PWAClient interface.
type MockPWAClient struct {
    ManifestURL string
    DN          *datanode.DataNode
    Err         error
}

// NewMockPWAClient creates a new MockPWAClient.
func NewMockPWAClient(manifestURL string, dn *datanode.DataNode, err error) PWAClient {
    return &MockPWAClient{
        ManifestURL: manifestURL,
        DN:          dn,
        Err:         err,
    }
}

// FindManifest mocks the finding of a PWA manifest.
func (m *MockPWAClient) FindManifest(pwaURL string) (string, error) {
    return m.ManifestURL, m.Err
}

// DownloadAndPackagePWA mocks the downloading and packaging of a PWA.
func (m *MockPWAClient) DownloadAndPackagePWA(pwaURL, manifestURL string, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
    return m.DN, m.Err
}
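A condensed view of how the two interface methods compose for a caller: locate the manifest behind a page, then let the client pull the manifest, start_url and icons into a DataNode. The target URL and bar size below are placeholders:

package main

import (
    "log"

    "github.com/Snider/Borg/pkg/pwa"
    "github.com/schollz/progressbar/v3"
)

func main() {
    client := pwa.NewPWAClient()

    manifestURL, err := client.FindManifest("https://app.example.org")
    if err != nil {
        log.Fatalf("no PWA manifest found: %v", err)
    }

    // The bar is nil-checked inside downloadAndAdd, but passing one gives per-asset progress.
    bar := progressbar.New(3)
    dn, err := client.DownloadAndPackagePWA("https://app.example.org", manifestURL, bar)
    if err != nil {
        log.Fatalf("packaging failed: %v", err)
    }
    _ = dn // hand the DataNode to the rest of the pipeline (compression, storage, ...)
}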
@@ -1,32 +1,15 @@
package pwa

import (
    "net"
    "net/http"
    "net/http/httptest"
    "testing"
    "time"

    "github.com/schollz/progressbar/v3"
)

func newTestPWAClient(serverURL string) PWAClient {
    return &pwaClient{
        client: &http.Client{
            Transport: &http.Transport{
                Proxy: http.ProxyFromEnvironment,
                DialContext: (&net.Dialer{
                    Timeout:   30 * time.Second,
                    KeepAlive: 30 * time.Second,
                }).DialContext,
                ForceAttemptHTTP2:     true,
                MaxIdleConns:          100,
                IdleConnTimeout:       90 * time.Second,
                TLSHandshakeTimeout:   10 * time.Second,
                ExpectContinueTimeout: 1 * time.Second,
            },
        },
    }
func newTestPWAClient() PWAClient {
    return NewPWAClient()
}

func TestFindManifest(t *testing.T) {

@@ -47,7 +30,7 @@ func TestFindManifest(t *testing.T) {
    }))
    defer server.Close()

    client := newTestPWAClient(server.URL)
    client := newTestPWAClient()
    expectedURL := server.URL + "/manifest.json"
    actualURL, err := client.FindManifest(server.URL)
    if err != nil {

@@ -102,7 +85,7 @@ func TestDownloadAndPackagePWA(t *testing.T) {
    }))
    defer server.Close()

    client := newTestPWAClient(server.URL)
    client := newTestPWAClient()
    bar := progressbar.New(1)
    dn, err := client.DownloadAndPackagePWA(server.URL, server.URL+"/manifest.json", bar)
    if err != nil {

@@ -111,6 +94,7 @@ func TestDownloadAndPackagePWA(t *testing.T) {

    expectedFiles := []string{"manifest.json", "index.html", "icon.png"}
    for _, file := range expectedFiles {
        // The path in the datanode is relative to the root of the domain, so we need to remove the leading slash.
        exists, err := dn.Exists(file)
        if err != nil {
            t.Fatalf("Exists failed for %s: %v", file, err)

@@ -122,7 +106,7 @@ func TestDownloadAndPackagePWA(t *testing.T) {
}

func TestResolveURL(t *testing.T) {
    client := NewPWAClient()
    client := NewPWAClient().(*pwaClient)
    tests := []struct {
        base string
        ref  string

@@ -137,7 +121,7 @@ func TestResolveURL(t *testing.T) {
    }

    for _, tt := range tests {
        got, err := client.(*pwaClient).resolveURL(tt.base, tt.ref)
        got, err := client.resolveURL(tt.base, tt.ref)
        if err != nil {
            t.Errorf("resolveURL(%q, %q) returned error: %v", tt.base, tt.ref, err)
            continue

@@ -147,3 +131,35 @@ func TestResolveURL(t *testing.T) {
        }
    }
}

func TestPWA_Bad(t *testing.T) {
    client := NewPWAClient()

    // Test FindManifest with no manifest
    server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        w.Header().Set("Content-Type", "text/html")
        w.Write([]byte(`
<!DOCTYPE html>
<html>
<head>
<title>Test PWA</title>
</head>
<body>
<h1>Hello, PWA!</h1>
</body>
</html>
`))
    }))
    defer server.Close()

    _, err := client.FindManifest(server.URL)
    if err == nil {
        t.Fatalf("expected an error, but got none")
    }

    // Test DownloadAndPackagePWA with bad manifest
    _, err = client.DownloadAndPackagePWA(server.URL, server.URL+"/manifest.json", nil)
    if err == nil {
        t.Fatalf("expected an error, but got none")
    }
}
@@ -67,23 +67,16 @@ type tarFile struct {
    modTime time.Time
}

// Close implements http.File Close with a no-op for tar-backed files.
func (f *tarFile) Close() error { return nil }

// Read implements io.Reader by delegating to the underlying bytes.Reader.
func (f *tarFile) Close() error { return nil }
func (f *tarFile) Read(p []byte) (int, error) { return f.content.Read(p) }

// Seek implements io.Seeker by delegating to the underlying bytes.Reader.
func (f *tarFile) Seek(offset int64, whence int) (int64, error) {
    return f.content.Seek(offset, whence)
}

// Readdir is unsupported for files in the tar filesystem and returns os.ErrInvalid.
func (f *tarFile) Readdir(count int) ([]os.FileInfo, error) {
    return nil, os.ErrInvalid
}

// Stat returns a FileInfo describing the tar-backed file.
func (f *tarFile) Stat() (os.FileInfo, error) {
    return &tarFileInfo{
        name: path.Base(f.header.Name),

@@ -99,20 +92,9 @@ type tarFileInfo struct {
    modTime time.Time
}

// Name returns the base name of the tar file.
func (i *tarFileInfo) Name() string { return i.name }

// Size returns the size of the tar file in bytes.
func (i *tarFileInfo) Size() int64 { return i.size }

// Mode returns a read-only file mode for tar entries.
func (i *tarFileInfo) Mode() os.FileMode { return 0444 }

// ModTime returns the modification time recorded in the tar header.
func (i *tarFileInfo) Name() string { return i.name }
func (i *tarFileInfo) Size() int64 { return i.size }
func (i *tarFileInfo) Mode() os.FileMode { return 0444 }
func (i *tarFileInfo) ModTime() time.Time { return i.modTime }

// IsDir reports whether the entry is a directory (always false for files here).
func (i *tarFileInfo) IsDir() bool { return false }

// Sys returns underlying data source (unused for tar entries).
func (i *tarFileInfo) Sys() interface{} { return nil }
func (i *tarFileInfo) IsDir() bool { return false }
func (i *tarFileInfo) Sys() interface{} { return nil }
@@ -10,8 +10,6 @@ import (
    "github.com/mattn/go-isatty"
)

// NonInteractivePrompter periodically prints status quotes when the terminal is non-interactive.
// It is intended to give the user feedback during long-running operations.
type NonInteractivePrompter struct {
    stopChan  chan struct{}
    quoteFunc func() (string, error)

@@ -20,8 +18,6 @@ type NonInteractivePrompter struct {
    stopOnce sync.Once
}

// NewNonInteractivePrompter returns a new NonInteractivePrompter that will call
// the provided quoteFunc to retrieve text to display.
func NewNonInteractivePrompter(quoteFunc func() (string, error)) *NonInteractivePrompter {
    return &NonInteractivePrompter{
        stopChan: make(chan struct{}),

@@ -29,7 +25,6 @@ func NewNonInteractivePrompter(quoteFunc func() (string, error)) *NonInteractive
    }
}

// Start begins the periodic display of quotes until Stop is called.
func (p *NonInteractivePrompter) Start() {
    p.mu.Lock()
    if p.started {

@@ -64,7 +59,6 @@ func (p *NonInteractivePrompter) Start() {
    }()
}

// Stop signals the prompter to stop printing further messages.
func (p *NonInteractivePrompter) Stop() {
    if p.IsInteractive() {
        return

@@ -74,7 +68,6 @@ func (p *NonInteractivePrompter) Stop() {
    })
}

// IsInteractive reports whether stdout is attached to an interactive terminal.
func (p *NonInteractivePrompter) IsInteractive() bool {
    return isatty.IsTerminal(os.Stdout.Fd()) || isatty.IsCygwinTerminal(os.Stdout.Fd())
}
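In practice the prompter wraps a long-running operation: on an interactive terminal it stays silent (a progress bar is shown instead), otherwise it prints periodic quotes, and Stop is safe to defer thanks to the sync.Once. A small sketch, with the sleep standing in for real work:

package main

import (
    "time"

    "github.com/Snider/Borg/pkg/ui"
)

func main() {
    // Feed the prompter VCS-flavoured quotes while a clone runs.
    prompter := ui.NewNonInteractivePrompter(ui.GetVCSQuote)
    prompter.Start()
    defer prompter.Stop()

    // Placeholder for the real work (cloning, downloading, packaging ...).
    time.Sleep(2 * time.Second)
}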
@@ -1,19 +1,16 @@
package ui

import "github.com/schollz/progressbar/v3"

// progressWriter implements io.Writer to update a progress bar with textual status.
type progressWriter struct {
    bar *progressbar.ProgressBar
}

// NewProgressWriter returns a writer that updates the provided progress bar's description.
func NewProgressWriter(bar *progressbar.ProgressBar) *progressWriter {
    return &progressWriter{bar: bar}
}

// Write updates the progress bar description with the provided bytes as a string.
// It returns the length of p to satisfy io.Writer semantics.
func (pw *progressWriter) Write(p []byte) (n int, err error) {
    if pw == nil || pw.bar == nil {
        return len(p), nil
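The writer is a small adapter so anything that reports progress through an io.Writer (a VCS clone's textual progress stream, for instance) can feed the bar's description line. A sketch that assumes nothing beyond what the type itself exposes; the progress strings are made up:

package main

import (
    "fmt"

    "github.com/Snider/Borg/pkg/ui"
    "github.com/schollz/progressbar/v3"
)

func main() {
    bar := progressbar.New(10)
    w := ui.NewProgressWriter(bar) // io.Writer that rewrites the bar's description

    // Any textual progress stream can now be pointed at the bar.
    fmt.Fprintf(w, "Counting objects: %d%%", 40)
    fmt.Fprintf(w, "Receiving objects: %d%%", 90)
}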
@@ -16,7 +16,6 @@ var (
    quotesErr error
)

// init seeds the random number generator for quote selection.
func init() {
    rand.Seed(time.Now().UnixNano())
}

@@ -42,7 +41,6 @@ type Quotes struct {
    } `json:"image_related"`
}

// loadQuotes reads and unmarshals the embedded quotes.json file into a Quotes struct.
func loadQuotes() (*Quotes, error) {
    quotesFile, err := QuotesJSON.ReadFile("quotes.json")
    if err != nil {

@@ -56,7 +54,6 @@ func loadQuotes() (*Quotes, error) {
    return &quotes, nil
}

// getQuotes returns the cached Quotes, loading them once on first use.
func getQuotes() (*Quotes, error) {
    quotesOnce.Do(func() {
        cachedQuotes, quotesErr = loadQuotes()

@@ -64,7 +61,6 @@ func getQuotes() (*Quotes, error) {
    return cachedQuotes, quotesErr
}

// GetRandomQuote returns a random quote string from the combined quote sets.
func GetRandomQuote() (string, error) {
    quotes, err := getQuotes()
    if err != nil {

@@ -86,7 +82,6 @@ func GetRandomQuote() (string, error) {
    return allQuotes[rand.Intn(len(allQuotes))], nil
}

// PrintQuote prints a random quote to stdout in green for user feedback.
func PrintQuote() {
    quote, err := GetRandomQuote()
    if err != nil {

@@ -97,7 +92,6 @@ func PrintQuote() {
    c.Println(quote)
}

// GetVCSQuote returns a random quote related to VCS processing.
func GetVCSQuote() (string, error) {
    quotes, err := getQuotes()
    if err != nil {

@@ -109,7 +103,6 @@ func GetVCSQuote() (string, error) {
    return quotes.VCSProcessing[rand.Intn(len(quotes.VCSProcessing))], nil
}

// GetPWAQuote returns a random quote related to PWA processing.
func GetPWAQuote() (string, error) {
    quotes, err := getQuotes()
    if err != nil {

@@ -121,7 +114,6 @@ func GetPWAQuote() (string, error) {
    return quotes.PWAProcessing[rand.Intn(len(quotes.PWAProcessing))], nil
}

// GetWebsiteQuote returns a random quote related to website processing.
func GetWebsiteQuote() (string, error) {
    quotes, err := getQuotes()
    if err != nil {
@@ -13,6 +13,8 @@ import (
    "golang.org/x/net/html"
)

var DownloadAndPackageWebsite = downloadAndPackageWebsite

// Downloader is a recursive website downloader.
type Downloader struct {
    baseURL *url.URL

@@ -38,8 +40,8 @@ func NewDownloaderWithClient(maxDepth int, client *http.Client) *Downloader {
    }
}

// DownloadAndPackageWebsite downloads a website and packages it into a DataNode.
func DownloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
// downloadAndPackageWebsite downloads a website and packages it into a DataNode.
func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
    baseURL, err := url.Parse(startURL)
    if err != nil {
        return nil, err

@@ -53,8 +55,6 @@ func DownloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.P
    return d.dn, nil
}

// crawl traverses the website starting from pageURL up to maxDepth,
// storing pages and discovered local assets into the DataNode.
func (d *Downloader) crawl(pageURL string, depth int) {
    if depth > d.maxDepth || d.visited[pageURL] {
        return

@@ -112,7 +112,6 @@ func (d *Downloader) crawl(pageURL string, depth int) {
    f(doc)
}

// downloadAsset fetches a single asset URL and stores it into the DataNode.
func (d *Downloader) downloadAsset(assetURL string) {
    if d.visited[assetURL] {
        return

@@ -139,7 +138,6 @@ func (d *Downloader) downloadAsset(assetURL string) {
    d.dn.AddData(relPath, body)
}

// getRelativePath returns a path relative to the site's root for the given URL.
func (d *Downloader) getRelativePath(pageURL string) string {
    u, err := url.Parse(pageURL)
    if err != nil {

@@ -148,7 +146,6 @@ func (d *Downloader) getRelativePath(pageURL string) string {
    return strings.TrimPrefix(u.Path, "/")
}

// resolveURL resolves ref against base and returns the absolute URL string.
func (d *Downloader) resolveURL(base, ref string) (string, error) {
    baseURL, err := url.Parse(base)
    if err != nil {

@@ -161,7 +158,6 @@ func (d *Downloader) resolveURL(base, ref string) (string, error) {
    return baseURL.ResolveReference(refURL).String(), nil
}

// isLocal reports whether pageURL belongs to the same host as the base URL.
func (d *Downloader) isLocal(pageURL string) bool {
    u, err := url.Parse(pageURL)
    if err != nil {

@@ -170,7 +166,6 @@ func (d *Downloader) isLocal(pageURL string) bool {
    return u.Hostname() == d.baseURL.Hostname()
}

// isAsset reports whether the URL likely points to a static asset we should download.
func isAsset(pageURL string) bool {
    ext := []string{".css", ".js", ".png", ".jpg", ".jpeg", ".gif", ".svg", ".ico"}
    for _, e := range ext {
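The point of the new package-level variable is the same test seam used elsewhere in this PR: production code keeps calling DownloadAndPackageWebsite, and a test can swap the variable for a stub and restore it afterwards. A minimal sketch, assuming the package is importable at pkg/website alongside the other collectors:

package example_test

import (
    "testing"

    "github.com/Snider/Borg/pkg/datanode"
    "github.com/Snider/Borg/pkg/website"
    "github.com/schollz/progressbar/v3"
)

func TestWithStubbedWebsiteDownload(t *testing.T) {
    old := website.DownloadAndPackageWebsite
    t.Cleanup(func() { website.DownloadAndPackageWebsite = old })

    // Replace the real crawler with a stub that returns a canned DataNode.
    website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
        dn := datanode.New()
        dn.AddData("index.html", []byte("<html>stub</html>"))
        return dn, nil
    }

    dn, err := website.DownloadAndPackageWebsite("https://example.org", 1, nil)
    if err != nil {
        t.Fatal(err)
    }
    if ok, _ := dn.Exists("index.html"); !ok {
        t.Fatal("expected stubbed index.html")
    }
}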