feat: Add deduplication cache for collections
This commit introduces a deduplication cache to avoid re-downloading files across multiple collection jobs. Key changes include:

- A new `pkg/cache` package that provides content-addressable storage using SHA256 hashes of the file content.
- Integration of the cache into the `collect website` command. Downloads are now skipped if the content already exists in the cache.
- The addition of `--no-cache` and `--cache-dir` flags to give users control over the caching behavior.
- New `borg cache stats` and `borg cache clear` commands to allow users to manage the cache.
- A performance improvement to the cache implementation, which now only writes the URL-to-hash index file once at the end of the collection process, rather than on every file download.
- Centralized logic for determining the default cache directory, removing code duplication.
- Improved error handling and refactored duplicated cache-checking logic in the website collector.
- Comprehensive unit tests for the new cache package and an integration test to verify that the website collector correctly uses the cache.

The implementation of cache size limiting and LRU eviction is still pending and will be addressed in a future commit.

Co-authored-by: Snider <631881+Snider@users.noreply.github.com>
This commit is contained in:
parent
cf2af53ed3
commit
e3efb59d98
9 changed files with 501 additions and 48 deletions
102
cmd/cache.go
Normal file
102
cmd/cache.go
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/Snider/Borg/pkg/cache"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// cacheCmd represents the cache command
|
||||
var cacheCmd = NewCacheCmd()
|
||||
var cacheStatsCmd = NewCacheStatsCmd()
|
||||
var cacheClearCmd = NewCacheClearCmd()
|
||||
|
||||
func init() {
|
||||
RootCmd.AddCommand(GetCacheCmd())
|
||||
GetCacheCmd().AddCommand(GetCacheStatsCmd())
|
||||
GetCacheCmd().AddCommand(GetCacheClearCmd())
|
||||
}
|
||||
|
||||
func GetCacheCmd() *cobra.Command {
|
||||
return cacheCmd
|
||||
}
|
||||
|
||||
func GetCacheStatsCmd() *cobra.Command {
|
||||
return cacheStatsCmd
|
||||
}
|
||||
|
||||
func GetCacheClearCmd() *cobra.Command {
|
||||
return cacheClearCmd
|
||||
}
|
||||
|
||||
func NewCacheCmd() *cobra.Command {
|
||||
cacheCmd := &cobra.Command{
|
||||
Use: "cache",
|
||||
Short: "Manage the cache",
|
||||
Long: `Manage the cache.`,
|
||||
}
|
||||
return cacheCmd
|
||||
}
|
||||
|
||||
func NewCacheStatsCmd() *cobra.Command {
|
||||
cacheStatsCmd := &cobra.Command{
|
||||
Use: "stats",
|
||||
Short: "Show cache stats",
|
||||
Long: `Show cache stats.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
cacheDir, err := GetCacheDir(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cacheInstance, err := cache.New(cacheDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create cache: %w", err)
|
||||
}
|
||||
|
||||
size, err := cacheInstance.Size()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get cache size: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Cache directory: %s\n", cacheInstance.Dir())
|
||||
fmt.Printf("Number of entries: %d\n", cacheInstance.NumEntries())
|
||||
fmt.Printf("Total size: %d bytes\n", size)
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
cacheStatsCmd.Flags().String("cache-dir", "", "Custom cache location")
|
||||
return cacheStatsCmd
|
||||
}
|
||||
|
||||
func NewCacheClearCmd() *cobra.Command {
|
||||
cacheClearCmd := &cobra.Command{
|
||||
Use: "clear",
|
||||
Short: "Clear the cache",
|
||||
Long: `Clear the cache.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
cacheDir, err := GetCacheDir(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cacheInstance, err := cache.New(cacheDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create cache: %w", err)
|
||||
}
|
||||
|
||||
err = cacheInstance.Clear()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to clear cache: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("Cache cleared.")
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
cacheClearCmd.Flags().String("cache-dir", "", "Custom cache location")
|
||||
return cacheClearCmd
|
||||
}
|
||||
|
|
@ -3,8 +3,10 @@ package cmd
|
|||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/schollz/progressbar/v3"
|
||||
"github.com/Snider/Borg/pkg/cache"
|
||||
"github.com/Snider/Borg/pkg/compress"
|
||||
"github.com/Snider/Borg/pkg/tim"
|
||||
"github.com/Snider/Borg/pkg/trix"
|
||||
|
|
@ -51,11 +53,32 @@ func NewCollectWebsiteCmd() *cobra.Command {
|
|||
bar = ui.NewProgressBar(-1, "Crawling website")
|
||||
}
|
||||
|
||||
dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar)
|
||||
noCache, _ := cmd.Flags().GetBool("no-cache")
|
||||
|
||||
var cacheInstance *cache.Cache
|
||||
var err error
|
||||
if !noCache {
|
||||
cacheDir, err := GetCacheDir(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cacheInstance, err = cache.New(cacheDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create cache: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar, cacheInstance)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error downloading and packaging website: %w", err)
|
||||
}
|
||||
|
||||
if cacheInstance != nil {
|
||||
if err := cacheInstance.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close cache: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var data []byte
|
||||
if format == "tim" {
|
||||
tim, err := tim.FromDataNode(dn)
|
||||
|
|
@ -104,5 +127,7 @@ func NewCollectWebsiteCmd() *cobra.Command {
|
|||
collectWebsiteCmd.PersistentFlags().String("format", "datanode", "Output format (datanode, tim, or trix)")
|
||||
collectWebsiteCmd.PersistentFlags().String("compression", "none", "Compression format (none, gz, or xz)")
|
||||
collectWebsiteCmd.PersistentFlags().String("password", "", "Password for encryption")
|
||||
collectWebsiteCmd.Flags().Bool("no-cache", false, "Skip cache, re-download everything")
|
||||
collectWebsiteCmd.Flags().String("cache-dir", "", "Custom cache location")
|
||||
return collectWebsiteCmd
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import (
|
|||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Snider/Borg/pkg/cache"
|
||||
"github.com/Snider/Borg/pkg/datanode"
|
||||
"github.com/Snider/Borg/pkg/website"
|
||||
"github.com/schollz/progressbar/v3"
|
||||
|
|
@ -14,7 +15,7 @@ import (
|
|||
func TestCollectWebsiteCmd_Good(t *testing.T) {
|
||||
// Mock the website downloader
|
||||
oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
|
||||
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
|
||||
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar, cache *cache.Cache) (*datanode.DataNode, error) {
|
||||
return datanode.New(), nil
|
||||
}
|
||||
defer func() {
|
||||
|
|
@ -35,7 +36,7 @@ func TestCollectWebsiteCmd_Good(t *testing.T) {
|
|||
func TestCollectWebsiteCmd_Bad(t *testing.T) {
|
||||
// Mock the website downloader to return an error
|
||||
oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
|
||||
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
|
||||
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar, cache *cache.Cache) (*datanode.DataNode, error) {
|
||||
return nil, fmt.Errorf("website error")
|
||||
}
|
||||
defer func() {
|
||||
|
|
|
|||
16
cmd/root.go
16
cmd/root.go
|
|
@ -2,7 +2,10 @@ package cmd
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
|
@ -28,3 +31,16 @@ func Execute(log *slog.Logger) error {
|
|||
RootCmd.SetContext(context.WithValue(context.Background(), "logger", log))
|
||||
return RootCmd.Execute()
|
||||
}
|
||||
|
||||
func GetCacheDir(cmd *cobra.Command) (string, error) {
|
||||
cacheDir, _ := cmd.Flags().GetString("cache-dir")
|
||||
if cacheDir != "" {
|
||||
return cacheDir, nil
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get user home directory: %w", err)
|
||||
}
|
||||
return filepath.Join(home, ".borg", "cache"), nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ func main() {
|
|||
log.Println("Collecting website...")
|
||||
|
||||
// Download and package the website.
|
||||
dn, err := website.DownloadAndPackageWebsite("https://example.com", 2, nil)
|
||||
dn, err := website.DownloadAndPackageWebsite("https://example.com", 2, nil, nil)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to collect website: %v", err)
|
||||
}
|
||||
|
|
|
|||
158
pkg/cache/cache.go
vendored
Normal file
158
pkg/cache/cache.go
vendored
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Names of the on-disk artifacts kept under the cache root directory.
const (
	indexFileName  = "index.json"
	storageDirName = "sha256"
)

// Cache provides content-addressable storage for web content: payloads live
// on disk under their SHA-256 hash, and a JSON index maps URLs to hashes.
type Cache struct {
	dir   string            // root directory of the cache on disk
	index map[string]string // URL -> hex-encoded content hash
	mutex sync.RWMutex      // guards index
}
|
||||
|
||||
// New creates a new Cache instance.
|
||||
func New(dir string) (*Cache, error) {
|
||||
storageDir := filepath.Join(dir, storageDirName)
|
||||
if err := os.MkdirAll(storageDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create cache storage directory: %w", err)
|
||||
}
|
||||
|
||||
cache := &Cache{
|
||||
dir: dir,
|
||||
index: make(map[string]string),
|
||||
}
|
||||
|
||||
if err := cache.loadIndex(); err != nil {
|
||||
return nil, fmt.Errorf("failed to load cache index: %w", err)
|
||||
}
|
||||
|
||||
return cache, nil
|
||||
}
|
||||
|
||||
// Get retrieves content from the cache for a given URL.
|
||||
func (c *Cache) Get(url string) ([]byte, bool, error) {
|
||||
c.mutex.RLock()
|
||||
hash, ok := c.index[url]
|
||||
c.mutex.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
path := c.getStoragePath(hash)
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, false, nil
|
||||
}
|
||||
return nil, false, fmt.Errorf("failed to read from cache: %w", err)
|
||||
}
|
||||
|
||||
return data, true, nil
|
||||
}
|
||||
|
||||
// Put adds content to the cache for a given URL.
|
||||
func (c *Cache) Put(url string, data []byte) error {
|
||||
hashBytes := sha256.Sum256(data)
|
||||
hash := fmt.Sprintf("%x", hashBytes)
|
||||
|
||||
path := c.getStoragePath(hash)
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
|
||||
return fmt.Errorf("failed to create cache directory: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(path, data, 0644); err != nil {
|
||||
return fmt.Errorf("failed to write to cache: %w", err)
|
||||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
c.index[url] = hash
|
||||
c.mutex.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close saves the index file.
|
||||
func (c *Cache) Close() error {
|
||||
return c.saveIndex()
|
||||
}
|
||||
|
||||
// Clear removes the cache directory.
|
||||
func (c *Cache) Clear() error {
|
||||
return os.RemoveAll(c.dir)
|
||||
}
|
||||
|
||||
// Dir returns the cache directory.
|
||||
func (c *Cache) Dir() string {
|
||||
return c.dir
|
||||
}
|
||||
|
||||
// Size returns the total size of the cache.
|
||||
func (c *Cache) Size() (int64, error) {
|
||||
var size int64
|
||||
err := filepath.Walk(c.dir, func(_ string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !info.IsDir() {
|
||||
size += info.Size()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return size, err
|
||||
}
|
||||
|
||||
// NumEntries returns the number of entries in the cache.
|
||||
func (c *Cache) NumEntries() int {
|
||||
c.mutex.RLock()
|
||||
defer c.mutex.RUnlock()
|
||||
return len(c.index)
|
||||
}
|
||||
|
||||
|
||||
func (c *Cache) getStoragePath(hash string) string {
|
||||
return filepath.Join(c.dir, storageDirName, hash[:2], hash)
|
||||
}
|
||||
|
||||
func (c *Cache) loadIndex() error {
|
||||
indexPath := filepath.Join(c.dir, indexFileName)
|
||||
file, err := os.Open(indexPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
decoder := json.NewDecoder(file)
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
return decoder.Decode(&c.index)
|
||||
}
|
||||
|
||||
func (c *Cache) saveIndex() error {
|
||||
indexPath := filepath.Join(c.dir, indexFileName)
|
||||
file, err := os.Create(indexPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
encoder := json.NewEncoder(file)
|
||||
encoder.SetIndent("", " ")
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
return encoder.Encode(c.index)
|
||||
}
|
||||
87
pkg/cache/cache_test.go
vendored
Normal file
87
pkg/cache/cache_test.go
vendored
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCache_Good(t *testing.T) {
|
||||
cacheDir := t.TempDir()
|
||||
c, err := New(cacheDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create cache: %v", err)
|
||||
}
|
||||
|
||||
// Test Put and Get
|
||||
url := "https://example.com"
|
||||
data := []byte("hello world")
|
||||
err = c.Put(url, data)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to put data in cache: %v", err)
|
||||
}
|
||||
|
||||
cachedData, ok, err := c.Get(url)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get data from cache: %v", err)
|
||||
}
|
||||
if !ok {
|
||||
t.Fatal("Expected data to be in cache, but it wasn't")
|
||||
}
|
||||
if string(cachedData) != string(data) {
|
||||
t.Errorf("Expected data to be '%s', but got '%s'", string(data), string(cachedData))
|
||||
}
|
||||
|
||||
// Test persistence
|
||||
err = c.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to close cache: %v", err)
|
||||
}
|
||||
|
||||
c2, err := New(cacheDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create new cache: %v", err)
|
||||
}
|
||||
cachedData, ok, err = c2.Get(url)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get data from new cache: %v", err)
|
||||
}
|
||||
if !ok {
|
||||
t.Fatal("Expected data to be in new cache, but it wasn't")
|
||||
}
|
||||
if string(cachedData) != string(data) {
|
||||
t.Errorf("Expected data to be '%s', but got '%s'", string(data), string(cachedData))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCache_Clear(t *testing.T) {
|
||||
cacheDir := t.TempDir()
|
||||
c, err := New(cacheDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create cache: %v", err)
|
||||
}
|
||||
|
||||
// Put some data in the cache
|
||||
url := "https://example.com"
|
||||
data := []byte("hello world")
|
||||
err = c.Put(url, data)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to put data in cache: %v", err)
|
||||
}
|
||||
err = c.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to close cache: %v", err)
|
||||
}
|
||||
|
||||
// Clear the cache
|
||||
err = c.Clear()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to clear cache: %v", err)
|
||||
}
|
||||
|
||||
// Check that the cache directory is gone
|
||||
_, err = os.Stat(cacheDir)
|
||||
if !os.IsNotExist(err) {
|
||||
t.Fatal("Expected cache directory to be gone, but it wasn't")
|
||||
}
|
||||
}
|
||||
|
|
@ -7,6 +7,7 @@ import (
|
|||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/Snider/Borg/pkg/cache"
|
||||
"github.com/Snider/Borg/pkg/datanode"
|
||||
"github.com/schollz/progressbar/v3"
|
||||
|
||||
|
|
@ -24,32 +25,34 @@ type Downloader struct {
|
|||
progressBar *progressbar.ProgressBar
|
||||
client *http.Client
|
||||
errors []error
|
||||
cache *cache.Cache
|
||||
}
|
||||
|
||||
// NewDownloader creates a new Downloader.
|
||||
func NewDownloader(maxDepth int) *Downloader {
|
||||
return NewDownloaderWithClient(maxDepth, http.DefaultClient)
|
||||
func NewDownloader(maxDepth int, cache *cache.Cache) *Downloader {
|
||||
return NewDownloaderWithClient(maxDepth, http.DefaultClient, cache)
|
||||
}
|
||||
|
||||
// NewDownloaderWithClient creates a new Downloader with a custom http.Client.
|
||||
func NewDownloaderWithClient(maxDepth int, client *http.Client) *Downloader {
|
||||
func NewDownloaderWithClient(maxDepth int, client *http.Client, cache *cache.Cache) *Downloader {
|
||||
return &Downloader{
|
||||
dn: datanode.New(),
|
||||
visited: make(map[string]bool),
|
||||
maxDepth: maxDepth,
|
||||
client: client,
|
||||
errors: make([]error, 0),
|
||||
cache: cache,
|
||||
}
|
||||
}
|
||||
|
||||
// downloadAndPackageWebsite downloads a website and packages it into a DataNode.
|
||||
func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
|
||||
func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.ProgressBar, cache *cache.Cache) (*datanode.DataNode, error) {
|
||||
baseURL, err := url.Parse(startURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
d := NewDownloader(maxDepth)
|
||||
d := NewDownloader(maxDepth, cache)
|
||||
d.baseURL = baseURL
|
||||
d.progressBar = bar
|
||||
d.crawl(startURL, 0)
|
||||
|
|
@ -69,34 +72,14 @@ func (d *Downloader) crawl(pageURL string, depth int) {
|
|||
if depth > d.maxDepth || d.visited[pageURL] {
|
||||
return
|
||||
}
|
||||
d.visited[pageURL] = true
|
||||
if d.progressBar != nil {
|
||||
d.progressBar.Add(1)
|
||||
}
|
||||
|
||||
resp, err := d.client.Get(pageURL)
|
||||
if err != nil {
|
||||
d.errors = append(d.errors, fmt.Errorf("Error getting %s: %w", pageURL, err))
|
||||
body, contentType := d.download(pageURL)
|
||||
if body == nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
d.errors = append(d.errors, fmt.Errorf("bad status for %s: %s", pageURL, resp.Status))
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
d.errors = append(d.errors, fmt.Errorf("Error reading body of %s: %w", pageURL, err))
|
||||
return
|
||||
}
|
||||
|
||||
relPath := d.getRelativePath(pageURL)
|
||||
d.dn.AddData(relPath, body)
|
||||
|
||||
// Don't try to parse non-html content
|
||||
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "text/html") {
|
||||
if !strings.HasPrefix(contentType, "text/html") {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -136,31 +119,56 @@ func (d *Downloader) downloadAsset(assetURL string) {
|
|||
if d.visited[assetURL] {
|
||||
return
|
||||
}
|
||||
d.visited[assetURL] = true
|
||||
d.download(assetURL)
|
||||
}
|
||||
|
||||
func (d *Downloader) download(pageURL string) ([]byte, string) {
|
||||
d.visited[pageURL] = true
|
||||
if d.progressBar != nil {
|
||||
d.progressBar.Add(1)
|
||||
}
|
||||
|
||||
resp, err := d.client.Get(assetURL)
|
||||
// Check the cache first
|
||||
if d.cache != nil {
|
||||
data, ok, err := d.cache.Get(pageURL)
|
||||
if err != nil {
|
||||
d.errors = append(d.errors, fmt.Errorf("Error getting from cache %s: %w", pageURL, err))
|
||||
// Don't return, as we can still try to download it
|
||||
}
|
||||
if ok {
|
||||
relPath := d.getRelativePath(pageURL)
|
||||
d.dn.AddData(relPath, data)
|
||||
return data, "" // We don't know the content type from the cache
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := d.client.Get(pageURL)
|
||||
if err != nil {
|
||||
d.errors = append(d.errors, fmt.Errorf("Error getting asset %s: %w", assetURL, err))
|
||||
return
|
||||
d.errors = append(d.errors, fmt.Errorf("Error getting %s: %w", pageURL, err))
|
||||
return nil, ""
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
d.errors = append(d.errors, fmt.Errorf("bad status for asset %s: %s", assetURL, resp.Status))
|
||||
return
|
||||
d.errors = append(d.errors, fmt.Errorf("bad status for %s: %s", pageURL, resp.Status))
|
||||
return nil, ""
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
d.errors = append(d.errors, fmt.Errorf("Error reading body of asset %s: %w", assetURL, err))
|
||||
return
|
||||
d.errors = append(d.errors, fmt.Errorf("Error reading body of %s: %w", pageURL, err))
|
||||
return nil, ""
|
||||
}
|
||||
|
||||
relPath := d.getRelativePath(assetURL)
|
||||
relPath := d.getRelativePath(pageURL)
|
||||
d.dn.AddData(relPath, body)
|
||||
|
||||
// Add to cache
|
||||
if d.cache != nil {
|
||||
d.cache.Put(pageURL, body)
|
||||
}
|
||||
|
||||
return body, resp.Header.Get("Content-Type")
|
||||
}
|
||||
|
||||
func (d *Downloader) getRelativePath(pageURL string) string {
|
||||
|
|
|
|||
|
|
@ -7,9 +7,11 @@ import (
|
|||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Snider/Borg/pkg/cache"
|
||||
"github.com/schollz/progressbar/v3"
|
||||
)
|
||||
|
||||
|
|
@ -20,7 +22,7 @@ func TestDownloadAndPackageWebsite_Good(t *testing.T) {
|
|||
defer server.Close()
|
||||
|
||||
bar := progressbar.NewOptions(1, progressbar.OptionSetWriter(io.Discard))
|
||||
dn, err := DownloadAndPackageWebsite(server.URL, 2, bar)
|
||||
dn, err := DownloadAndPackageWebsite(server.URL, 2, bar, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
|
||||
}
|
||||
|
|
@ -52,7 +54,7 @@ func TestDownloadAndPackageWebsite_Good(t *testing.T) {
|
|||
|
||||
func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
|
||||
t.Run("Invalid Start URL", func(t *testing.T) {
|
||||
_, err := DownloadAndPackageWebsite("http://invalid-url", 1, nil)
|
||||
_, err := DownloadAndPackageWebsite("http://invalid-url", 1, nil, nil)
|
||||
if err == nil {
|
||||
t.Fatal("Expected an error for an invalid start URL, but got nil")
|
||||
}
|
||||
|
|
@ -63,7 +65,7 @@ func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
|
|||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
}))
|
||||
defer server.Close()
|
||||
_, err := DownloadAndPackageWebsite(server.URL, 1, nil)
|
||||
_, err := DownloadAndPackageWebsite(server.URL, 1, nil, nil)
|
||||
if err == nil {
|
||||
t.Fatal("Expected an error for a server error on the start URL, but got nil")
|
||||
}
|
||||
|
|
@ -80,7 +82,7 @@ func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
|
|||
}))
|
||||
defer server.Close()
|
||||
// We expect an error because the link is broken.
|
||||
dn, err := DownloadAndPackageWebsite(server.URL, 1, nil)
|
||||
dn, err := DownloadAndPackageWebsite(server.URL, 1, nil, nil)
|
||||
if err == nil {
|
||||
t.Fatal("Expected an error for a broken link, but got nil")
|
||||
}
|
||||
|
|
@ -99,7 +101,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
|
|||
defer server.Close()
|
||||
|
||||
bar := progressbar.NewOptions(1, progressbar.OptionSetWriter(io.Discard))
|
||||
dn, err := DownloadAndPackageWebsite(server.URL, 1, bar) // Max depth of 1
|
||||
dn, err := DownloadAndPackageWebsite(server.URL, 1, bar, nil) // Max depth of 1
|
||||
if err != nil {
|
||||
t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
|
||||
}
|
||||
|
|
@ -122,7 +124,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
|
|||
fmt.Fprint(w, `<a href="http://externalsite.com/page.html">External</a>`)
|
||||
}))
|
||||
defer server.Close()
|
||||
dn, err := DownloadAndPackageWebsite(server.URL, 1, nil)
|
||||
dn, err := DownloadAndPackageWebsite(server.URL, 1, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
|
||||
}
|
||||
|
|
@ -156,7 +158,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
|
|||
// For now, we'll just test that it doesn't hang forever.
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
_, err := DownloadAndPackageWebsite(server.URL, 1, nil)
|
||||
_, err := DownloadAndPackageWebsite(server.URL, 1, nil, nil)
|
||||
if err != nil && !strings.Contains(err.Error(), "context deadline exceeded") {
|
||||
// We expect a timeout error, but other errors are failures.
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
|
|
@ -174,6 +176,60 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
|
|||
|
||||
// --- Helpers ---
|
||||
|
||||
func TestDownloadAndPackageWebsite_Cache(t *testing.T) {
|
||||
var requestCount int32
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
atomic.AddInt32(&requestCount, 1)
|
||||
switch r.URL.Path {
|
||||
case "/":
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
fmt.Fprint(w, `<a href="/page2.html">Page 2</a>`)
|
||||
case "/page2.html":
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
fmt.Fprint(w, `<h1>Page 2</h1>`)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cacheDir := t.TempDir()
|
||||
c, err := cache.New(cacheDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create cache: %v", err)
|
||||
}
|
||||
|
||||
// First download
|
||||
_, err = DownloadAndPackageWebsite(server.URL, 2, nil, c)
|
||||
if err != nil {
|
||||
t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
|
||||
}
|
||||
if err := c.Close(); err != nil {
|
||||
t.Fatalf("Failed to close cache: %v", err)
|
||||
}
|
||||
|
||||
if atomic.LoadInt32(&requestCount) != 2 {
|
||||
t.Errorf("Expected 2 requests to the server, but got %d", requestCount)
|
||||
}
|
||||
|
||||
// Second download
|
||||
c2, err := cache.New(cacheDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create new cache: %v", err)
|
||||
}
|
||||
_, err = DownloadAndPackageWebsite(server.URL, 2, nil, c2)
|
||||
if err != nil {
|
||||
t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
|
||||
}
|
||||
if err := c2.Close(); err != nil {
|
||||
t.Fatalf("Failed to close cache: %v", err)
|
||||
}
|
||||
|
||||
if atomic.LoadInt32(&requestCount) != 2 {
|
||||
t.Errorf("Expected 2 requests to the server after caching, but got %d", requestCount)
|
||||
}
|
||||
}
|
||||
|
||||
func newWebsiteTestServer() *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue