Compare commits
1 commit
main
...
feat/dedup
| Author | SHA1 | Date |
|---|---|---|
|
|
e3efb59d98 |
9 changed files with 501 additions and 48 deletions
102
cmd/cache.go
Normal file
102
cmd/cache.go
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/Snider/Borg/pkg/cache"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// cacheCmd represents the cache command
|
||||
var cacheCmd = NewCacheCmd()
|
||||
var cacheStatsCmd = NewCacheStatsCmd()
|
||||
var cacheClearCmd = NewCacheClearCmd()
|
||||
|
||||
func init() {
|
||||
RootCmd.AddCommand(GetCacheCmd())
|
||||
GetCacheCmd().AddCommand(GetCacheStatsCmd())
|
||||
GetCacheCmd().AddCommand(GetCacheClearCmd())
|
||||
}
|
||||
|
||||
func GetCacheCmd() *cobra.Command {
|
||||
return cacheCmd
|
||||
}
|
||||
|
||||
func GetCacheStatsCmd() *cobra.Command {
|
||||
return cacheStatsCmd
|
||||
}
|
||||
|
||||
func GetCacheClearCmd() *cobra.Command {
|
||||
return cacheClearCmd
|
||||
}
|
||||
|
||||
func NewCacheCmd() *cobra.Command {
|
||||
cacheCmd := &cobra.Command{
|
||||
Use: "cache",
|
||||
Short: "Manage the cache",
|
||||
Long: `Manage the cache.`,
|
||||
}
|
||||
return cacheCmd
|
||||
}
|
||||
|
||||
func NewCacheStatsCmd() *cobra.Command {
|
||||
cacheStatsCmd := &cobra.Command{
|
||||
Use: "stats",
|
||||
Short: "Show cache stats",
|
||||
Long: `Show cache stats.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
cacheDir, err := GetCacheDir(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cacheInstance, err := cache.New(cacheDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create cache: %w", err)
|
||||
}
|
||||
|
||||
size, err := cacheInstance.Size()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get cache size: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Cache directory: %s\n", cacheInstance.Dir())
|
||||
fmt.Printf("Number of entries: %d\n", cacheInstance.NumEntries())
|
||||
fmt.Printf("Total size: %d bytes\n", size)
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
cacheStatsCmd.Flags().String("cache-dir", "", "Custom cache location")
|
||||
return cacheStatsCmd
|
||||
}
|
||||
|
||||
func NewCacheClearCmd() *cobra.Command {
|
||||
cacheClearCmd := &cobra.Command{
|
||||
Use: "clear",
|
||||
Short: "Clear the cache",
|
||||
Long: `Clear the cache.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
cacheDir, err := GetCacheDir(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cacheInstance, err := cache.New(cacheDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create cache: %w", err)
|
||||
}
|
||||
|
||||
err = cacheInstance.Clear()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to clear cache: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("Cache cleared.")
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
cacheClearCmd.Flags().String("cache-dir", "", "Custom cache location")
|
||||
return cacheClearCmd
|
||||
}
|
||||
|
|
@ -3,8 +3,10 @@ package cmd
|
|||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/schollz/progressbar/v3"
|
||||
"github.com/Snider/Borg/pkg/cache"
|
||||
"github.com/Snider/Borg/pkg/compress"
|
||||
"github.com/Snider/Borg/pkg/tim"
|
||||
"github.com/Snider/Borg/pkg/trix"
|
||||
|
|
@ -51,11 +53,32 @@ func NewCollectWebsiteCmd() *cobra.Command {
|
|||
bar = ui.NewProgressBar(-1, "Crawling website")
|
||||
}
|
||||
|
||||
dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar)
|
||||
noCache, _ := cmd.Flags().GetBool("no-cache")
|
||||
|
||||
var cacheInstance *cache.Cache
|
||||
var err error
|
||||
if !noCache {
|
||||
cacheDir, err := GetCacheDir(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cacheInstance, err = cache.New(cacheDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create cache: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar, cacheInstance)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error downloading and packaging website: %w", err)
|
||||
}
|
||||
|
||||
if cacheInstance != nil {
|
||||
if err := cacheInstance.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close cache: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var data []byte
|
||||
if format == "tim" {
|
||||
tim, err := tim.FromDataNode(dn)
|
||||
|
|
@ -104,5 +127,7 @@ func NewCollectWebsiteCmd() *cobra.Command {
|
|||
collectWebsiteCmd.PersistentFlags().String("format", "datanode", "Output format (datanode, tim, or trix)")
|
||||
collectWebsiteCmd.PersistentFlags().String("compression", "none", "Compression format (none, gz, or xz)")
|
||||
collectWebsiteCmd.PersistentFlags().String("password", "", "Password for encryption")
|
||||
collectWebsiteCmd.Flags().Bool("no-cache", false, "Skip cache, re-download everything")
|
||||
collectWebsiteCmd.Flags().String("cache-dir", "", "Custom cache location")
|
||||
return collectWebsiteCmd
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import (
|
|||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Snider/Borg/pkg/cache"
|
||||
"github.com/Snider/Borg/pkg/datanode"
|
||||
"github.com/Snider/Borg/pkg/website"
|
||||
"github.com/schollz/progressbar/v3"
|
||||
|
|
@ -14,7 +15,7 @@ import (
|
|||
func TestCollectWebsiteCmd_Good(t *testing.T) {
|
||||
// Mock the website downloader
|
||||
oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
|
||||
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
|
||||
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar, cache *cache.Cache) (*datanode.DataNode, error) {
|
||||
return datanode.New(), nil
|
||||
}
|
||||
defer func() {
|
||||
|
|
@ -35,7 +36,7 @@ func TestCollectWebsiteCmd_Good(t *testing.T) {
|
|||
func TestCollectWebsiteCmd_Bad(t *testing.T) {
|
||||
// Mock the website downloader to return an error
|
||||
oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
|
||||
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
|
||||
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar, cache *cache.Cache) (*datanode.DataNode, error) {
|
||||
return nil, fmt.Errorf("website error")
|
||||
}
|
||||
defer func() {
|
||||
|
|
|
|||
16
cmd/root.go
16
cmd/root.go
|
|
@ -2,7 +2,10 @@ package cmd
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
|
@ -28,3 +31,16 @@ func Execute(log *slog.Logger) error {
|
|||
RootCmd.SetContext(context.WithValue(context.Background(), "logger", log))
|
||||
return RootCmd.Execute()
|
||||
}
|
||||
|
||||
func GetCacheDir(cmd *cobra.Command) (string, error) {
|
||||
cacheDir, _ := cmd.Flags().GetString("cache-dir")
|
||||
if cacheDir != "" {
|
||||
return cacheDir, nil
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get user home directory: %w", err)
|
||||
}
|
||||
return filepath.Join(home, ".borg", "cache"), nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ func main() {
|
|||
log.Println("Collecting website...")
|
||||
|
||||
// Download and package the website.
|
||||
dn, err := website.DownloadAndPackageWebsite("https://example.com", 2, nil)
|
||||
dn, err := website.DownloadAndPackageWebsite("https://example.com", 2, nil, nil)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to collect website: %v", err)
|
||||
}
|
||||
|
|
|
|||
158
pkg/cache/cache.go
vendored
Normal file
158
pkg/cache/cache.go
vendored
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
|
||||
indexFileName = "index.json"
|
||||
storageDirName = "sha256"
|
||||
)
|
||||
|
||||
// Cache provides a content-addressable storage for web content.
|
||||
type Cache struct {
|
||||
dir string
|
||||
index map[string]string
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// New creates a new Cache instance.
|
||||
func New(dir string) (*Cache, error) {
|
||||
storageDir := filepath.Join(dir, storageDirName)
|
||||
if err := os.MkdirAll(storageDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create cache storage directory: %w", err)
|
||||
}
|
||||
|
||||
cache := &Cache{
|
||||
dir: dir,
|
||||
index: make(map[string]string),
|
||||
}
|
||||
|
||||
if err := cache.loadIndex(); err != nil {
|
||||
return nil, fmt.Errorf("failed to load cache index: %w", err)
|
||||
}
|
||||
|
||||
return cache, nil
|
||||
}
|
||||
|
||||
// Get retrieves content from the cache for a given URL.
|
||||
func (c *Cache) Get(url string) ([]byte, bool, error) {
|
||||
c.mutex.RLock()
|
||||
hash, ok := c.index[url]
|
||||
c.mutex.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
path := c.getStoragePath(hash)
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, false, nil
|
||||
}
|
||||
return nil, false, fmt.Errorf("failed to read from cache: %w", err)
|
||||
}
|
||||
|
||||
return data, true, nil
|
||||
}
|
||||
|
||||
// Put adds content to the cache for a given URL.
|
||||
func (c *Cache) Put(url string, data []byte) error {
|
||||
hashBytes := sha256.Sum256(data)
|
||||
hash := fmt.Sprintf("%x", hashBytes)
|
||||
|
||||
path := c.getStoragePath(hash)
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
|
||||
return fmt.Errorf("failed to create cache directory: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(path, data, 0644); err != nil {
|
||||
return fmt.Errorf("failed to write to cache: %w", err)
|
||||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
c.index[url] = hash
|
||||
c.mutex.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close saves the index file.
|
||||
func (c *Cache) Close() error {
|
||||
return c.saveIndex()
|
||||
}
|
||||
|
||||
// Clear removes the cache directory.
|
||||
func (c *Cache) Clear() error {
|
||||
return os.RemoveAll(c.dir)
|
||||
}
|
||||
|
||||
// Dir returns the cache directory.
|
||||
func (c *Cache) Dir() string {
|
||||
return c.dir
|
||||
}
|
||||
|
||||
// Size returns the total size of the cache.
|
||||
func (c *Cache) Size() (int64, error) {
|
||||
var size int64
|
||||
err := filepath.Walk(c.dir, func(_ string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !info.IsDir() {
|
||||
size += info.Size()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return size, err
|
||||
}
|
||||
|
||||
// NumEntries returns the number of entries in the cache.
|
||||
func (c *Cache) NumEntries() int {
|
||||
c.mutex.RLock()
|
||||
defer c.mutex.RUnlock()
|
||||
return len(c.index)
|
||||
}
|
||||
|
||||
|
||||
func (c *Cache) getStoragePath(hash string) string {
|
||||
return filepath.Join(c.dir, storageDirName, hash[:2], hash)
|
||||
}
|
||||
|
||||
func (c *Cache) loadIndex() error {
|
||||
indexPath := filepath.Join(c.dir, indexFileName)
|
||||
file, err := os.Open(indexPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
decoder := json.NewDecoder(file)
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
return decoder.Decode(&c.index)
|
||||
}
|
||||
|
||||
func (c *Cache) saveIndex() error {
|
||||
indexPath := filepath.Join(c.dir, indexFileName)
|
||||
file, err := os.Create(indexPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
encoder := json.NewEncoder(file)
|
||||
encoder.SetIndent("", " ")
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
return encoder.Encode(c.index)
|
||||
}
|
||||
87
pkg/cache/cache_test.go
vendored
Normal file
87
pkg/cache/cache_test.go
vendored
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCache_Good(t *testing.T) {
|
||||
cacheDir := t.TempDir()
|
||||
c, err := New(cacheDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create cache: %v", err)
|
||||
}
|
||||
|
||||
// Test Put and Get
|
||||
url := "https://example.com"
|
||||
data := []byte("hello world")
|
||||
err = c.Put(url, data)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to put data in cache: %v", err)
|
||||
}
|
||||
|
||||
cachedData, ok, err := c.Get(url)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get data from cache: %v", err)
|
||||
}
|
||||
if !ok {
|
||||
t.Fatal("Expected data to be in cache, but it wasn't")
|
||||
}
|
||||
if string(cachedData) != string(data) {
|
||||
t.Errorf("Expected data to be '%s', but got '%s'", string(data), string(cachedData))
|
||||
}
|
||||
|
||||
// Test persistence
|
||||
err = c.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to close cache: %v", err)
|
||||
}
|
||||
|
||||
c2, err := New(cacheDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create new cache: %v", err)
|
||||
}
|
||||
cachedData, ok, err = c2.Get(url)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get data from new cache: %v", err)
|
||||
}
|
||||
if !ok {
|
||||
t.Fatal("Expected data to be in new cache, but it wasn't")
|
||||
}
|
||||
if string(cachedData) != string(data) {
|
||||
t.Errorf("Expected data to be '%s', but got '%s'", string(data), string(cachedData))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCache_Clear(t *testing.T) {
|
||||
cacheDir := t.TempDir()
|
||||
c, err := New(cacheDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create cache: %v", err)
|
||||
}
|
||||
|
||||
// Put some data in the cache
|
||||
url := "https://example.com"
|
||||
data := []byte("hello world")
|
||||
err = c.Put(url, data)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to put data in cache: %v", err)
|
||||
}
|
||||
err = c.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to close cache: %v", err)
|
||||
}
|
||||
|
||||
// Clear the cache
|
||||
err = c.Clear()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to clear cache: %v", err)
|
||||
}
|
||||
|
||||
// Check that the cache directory is gone
|
||||
_, err = os.Stat(cacheDir)
|
||||
if !os.IsNotExist(err) {
|
||||
t.Fatal("Expected cache directory to be gone, but it wasn't")
|
||||
}
|
||||
}
|
||||
|
|
@ -7,6 +7,7 @@ import (
|
|||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/Snider/Borg/pkg/cache"
|
||||
"github.com/Snider/Borg/pkg/datanode"
|
||||
"github.com/schollz/progressbar/v3"
|
||||
|
||||
|
|
@ -24,32 +25,34 @@ type Downloader struct {
|
|||
progressBar *progressbar.ProgressBar
|
||||
client *http.Client
|
||||
errors []error
|
||||
cache *cache.Cache
|
||||
}
|
||||
|
||||
// NewDownloader creates a new Downloader.
|
||||
func NewDownloader(maxDepth int) *Downloader {
|
||||
return NewDownloaderWithClient(maxDepth, http.DefaultClient)
|
||||
func NewDownloader(maxDepth int, cache *cache.Cache) *Downloader {
|
||||
return NewDownloaderWithClient(maxDepth, http.DefaultClient, cache)
|
||||
}
|
||||
|
||||
// NewDownloaderWithClient creates a new Downloader with a custom http.Client.
|
||||
func NewDownloaderWithClient(maxDepth int, client *http.Client) *Downloader {
|
||||
func NewDownloaderWithClient(maxDepth int, client *http.Client, cache *cache.Cache) *Downloader {
|
||||
return &Downloader{
|
||||
dn: datanode.New(),
|
||||
visited: make(map[string]bool),
|
||||
maxDepth: maxDepth,
|
||||
client: client,
|
||||
errors: make([]error, 0),
|
||||
cache: cache,
|
||||
}
|
||||
}
|
||||
|
||||
// downloadAndPackageWebsite downloads a website and packages it into a DataNode.
|
||||
func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
|
||||
func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.ProgressBar, cache *cache.Cache) (*datanode.DataNode, error) {
|
||||
baseURL, err := url.Parse(startURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
d := NewDownloader(maxDepth)
|
||||
d := NewDownloader(maxDepth, cache)
|
||||
d.baseURL = baseURL
|
||||
d.progressBar = bar
|
||||
d.crawl(startURL, 0)
|
||||
|
|
@ -69,34 +72,14 @@ func (d *Downloader) crawl(pageURL string, depth int) {
|
|||
if depth > d.maxDepth || d.visited[pageURL] {
|
||||
return
|
||||
}
|
||||
d.visited[pageURL] = true
|
||||
if d.progressBar != nil {
|
||||
d.progressBar.Add(1)
|
||||
}
|
||||
|
||||
resp, err := d.client.Get(pageURL)
|
||||
if err != nil {
|
||||
d.errors = append(d.errors, fmt.Errorf("Error getting %s: %w", pageURL, err))
|
||||
body, contentType := d.download(pageURL)
|
||||
if body == nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
d.errors = append(d.errors, fmt.Errorf("bad status for %s: %s", pageURL, resp.Status))
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
d.errors = append(d.errors, fmt.Errorf("Error reading body of %s: %w", pageURL, err))
|
||||
return
|
||||
}
|
||||
|
||||
relPath := d.getRelativePath(pageURL)
|
||||
d.dn.AddData(relPath, body)
|
||||
|
||||
// Don't try to parse non-html content
|
||||
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "text/html") {
|
||||
if !strings.HasPrefix(contentType, "text/html") {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -136,31 +119,56 @@ func (d *Downloader) downloadAsset(assetURL string) {
|
|||
if d.visited[assetURL] {
|
||||
return
|
||||
}
|
||||
d.visited[assetURL] = true
|
||||
d.download(assetURL)
|
||||
}
|
||||
|
||||
func (d *Downloader) download(pageURL string) ([]byte, string) {
|
||||
d.visited[pageURL] = true
|
||||
if d.progressBar != nil {
|
||||
d.progressBar.Add(1)
|
||||
}
|
||||
|
||||
resp, err := d.client.Get(assetURL)
|
||||
// Check the cache first
|
||||
if d.cache != nil {
|
||||
data, ok, err := d.cache.Get(pageURL)
|
||||
if err != nil {
|
||||
d.errors = append(d.errors, fmt.Errorf("Error getting from cache %s: %w", pageURL, err))
|
||||
// Don't return, as we can still try to download it
|
||||
}
|
||||
if ok {
|
||||
relPath := d.getRelativePath(pageURL)
|
||||
d.dn.AddData(relPath, data)
|
||||
return data, "" // We don't know the content type from the cache
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := d.client.Get(pageURL)
|
||||
if err != nil {
|
||||
d.errors = append(d.errors, fmt.Errorf("Error getting asset %s: %w", assetURL, err))
|
||||
return
|
||||
d.errors = append(d.errors, fmt.Errorf("Error getting %s: %w", pageURL, err))
|
||||
return nil, ""
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
d.errors = append(d.errors, fmt.Errorf("bad status for asset %s: %s", assetURL, resp.Status))
|
||||
return
|
||||
d.errors = append(d.errors, fmt.Errorf("bad status for %s: %s", pageURL, resp.Status))
|
||||
return nil, ""
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
d.errors = append(d.errors, fmt.Errorf("Error reading body of asset %s: %w", assetURL, err))
|
||||
return
|
||||
d.errors = append(d.errors, fmt.Errorf("Error reading body of %s: %w", pageURL, err))
|
||||
return nil, ""
|
||||
}
|
||||
|
||||
relPath := d.getRelativePath(assetURL)
|
||||
relPath := d.getRelativePath(pageURL)
|
||||
d.dn.AddData(relPath, body)
|
||||
|
||||
// Add to cache
|
||||
if d.cache != nil {
|
||||
d.cache.Put(pageURL, body)
|
||||
}
|
||||
|
||||
return body, resp.Header.Get("Content-Type")
|
||||
}
|
||||
|
||||
func (d *Downloader) getRelativePath(pageURL string) string {
|
||||
|
|
|
|||
|
|
@ -7,9 +7,11 @@ import (
|
|||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Snider/Borg/pkg/cache"
|
||||
"github.com/schollz/progressbar/v3"
|
||||
)
|
||||
|
||||
|
|
@ -20,7 +22,7 @@ func TestDownloadAndPackageWebsite_Good(t *testing.T) {
|
|||
defer server.Close()
|
||||
|
||||
bar := progressbar.NewOptions(1, progressbar.OptionSetWriter(io.Discard))
|
||||
dn, err := DownloadAndPackageWebsite(server.URL, 2, bar)
|
||||
dn, err := DownloadAndPackageWebsite(server.URL, 2, bar, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
|
||||
}
|
||||
|
|
@ -52,7 +54,7 @@ func TestDownloadAndPackageWebsite_Good(t *testing.T) {
|
|||
|
||||
func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
|
||||
t.Run("Invalid Start URL", func(t *testing.T) {
|
||||
_, err := DownloadAndPackageWebsite("http://invalid-url", 1, nil)
|
||||
_, err := DownloadAndPackageWebsite("http://invalid-url", 1, nil, nil)
|
||||
if err == nil {
|
||||
t.Fatal("Expected an error for an invalid start URL, but got nil")
|
||||
}
|
||||
|
|
@ -63,7 +65,7 @@ func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
|
|||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
}))
|
||||
defer server.Close()
|
||||
_, err := DownloadAndPackageWebsite(server.URL, 1, nil)
|
||||
_, err := DownloadAndPackageWebsite(server.URL, 1, nil, nil)
|
||||
if err == nil {
|
||||
t.Fatal("Expected an error for a server error on the start URL, but got nil")
|
||||
}
|
||||
|
|
@ -80,7 +82,7 @@ func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
|
|||
}))
|
||||
defer server.Close()
|
||||
// We expect an error because the link is broken.
|
||||
dn, err := DownloadAndPackageWebsite(server.URL, 1, nil)
|
||||
dn, err := DownloadAndPackageWebsite(server.URL, 1, nil, nil)
|
||||
if err == nil {
|
||||
t.Fatal("Expected an error for a broken link, but got nil")
|
||||
}
|
||||
|
|
@ -99,7 +101,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
|
|||
defer server.Close()
|
||||
|
||||
bar := progressbar.NewOptions(1, progressbar.OptionSetWriter(io.Discard))
|
||||
dn, err := DownloadAndPackageWebsite(server.URL, 1, bar) // Max depth of 1
|
||||
dn, err := DownloadAndPackageWebsite(server.URL, 1, bar, nil) // Max depth of 1
|
||||
if err != nil {
|
||||
t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
|
||||
}
|
||||
|
|
@ -122,7 +124,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
|
|||
fmt.Fprint(w, `<a href="http://externalsite.com/page.html">External</a>`)
|
||||
}))
|
||||
defer server.Close()
|
||||
dn, err := DownloadAndPackageWebsite(server.URL, 1, nil)
|
||||
dn, err := DownloadAndPackageWebsite(server.URL, 1, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
|
||||
}
|
||||
|
|
@ -156,7 +158,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
|
|||
// For now, we'll just test that it doesn't hang forever.
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
_, err := DownloadAndPackageWebsite(server.URL, 1, nil)
|
||||
_, err := DownloadAndPackageWebsite(server.URL, 1, nil, nil)
|
||||
if err != nil && !strings.Contains(err.Error(), "context deadline exceeded") {
|
||||
// We expect a timeout error, but other errors are failures.
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
|
|
@ -174,6 +176,60 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
|
|||
|
||||
// --- Helpers ---
|
||||
|
||||
func TestDownloadAndPackageWebsite_Cache(t *testing.T) {
|
||||
var requestCount int32
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
atomic.AddInt32(&requestCount, 1)
|
||||
switch r.URL.Path {
|
||||
case "/":
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
fmt.Fprint(w, `<a href="/page2.html">Page 2</a>`)
|
||||
case "/page2.html":
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
fmt.Fprint(w, `<h1>Page 2</h1>`)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cacheDir := t.TempDir()
|
||||
c, err := cache.New(cacheDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create cache: %v", err)
|
||||
}
|
||||
|
||||
// First download
|
||||
_, err = DownloadAndPackageWebsite(server.URL, 2, nil, c)
|
||||
if err != nil {
|
||||
t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
|
||||
}
|
||||
if err := c.Close(); err != nil {
|
||||
t.Fatalf("Failed to close cache: %v", err)
|
||||
}
|
||||
|
||||
if atomic.LoadInt32(&requestCount) != 2 {
|
||||
t.Errorf("Expected 2 requests to the server, but got %d", requestCount)
|
||||
}
|
||||
|
||||
// Second download
|
||||
c2, err := cache.New(cacheDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create new cache: %v", err)
|
||||
}
|
||||
_, err = DownloadAndPackageWebsite(server.URL, 2, nil, c2)
|
||||
if err != nil {
|
||||
t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
|
||||
}
|
||||
if err := c2.Close(); err != nil {
|
||||
t.Fatalf("Failed to close cache: %v", err)
|
||||
}
|
||||
|
||||
if atomic.LoadInt32(&requestCount) != 2 {
|
||||
t.Errorf("Expected 2 requests to the server after caching, but got %d", requestCount)
|
||||
}
|
||||
}
|
||||
|
||||
func newWebsiteTestServer() *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue