diff --git a/cmd/cache.go b/cmd/cache.go
new file mode 100644
index 0000000..f249188
--- /dev/null
+++ b/cmd/cache.go
@@ -0,0 +1,100 @@
+package cmd
+
+import (
+	"fmt"
+
+	"github.com/Snider/Borg/pkg/cache"
+	"github.com/spf13/cobra"
+)
+
+// cacheCmd represents the cache command
+var cacheCmd = NewCacheCmd()
+var cacheStatsCmd = NewCacheStatsCmd()
+var cacheClearCmd = NewCacheClearCmd()
+
+func init() {
+	RootCmd.AddCommand(GetCacheCmd())
+	GetCacheCmd().AddCommand(GetCacheStatsCmd())
+	GetCacheCmd().AddCommand(GetCacheClearCmd())
+}
+
+func GetCacheCmd() *cobra.Command {
+	return cacheCmd
+}
+
+func GetCacheStatsCmd() *cobra.Command {
+	return cacheStatsCmd
+}
+
+func GetCacheClearCmd() *cobra.Command {
+	return cacheClearCmd
+}
+
+func NewCacheCmd() *cobra.Command {
+	cacheCmd := &cobra.Command{
+		Use:   "cache",
+		Short: "Manage the cache",
+		Long:  `Manage the cache.`,
+	}
+	return cacheCmd
+}
+
+func NewCacheStatsCmd() *cobra.Command {
+	cacheStatsCmd := &cobra.Command{
+		Use:   "stats",
+		Short: "Show cache stats",
+		Long:  `Show cache stats.`,
+		RunE: func(cmd *cobra.Command, args []string) error {
+			cacheDir, err := GetCacheDir(cmd)
+			if err != nil {
+				return err
+			}
+			cacheInstance, err := cache.New(cacheDir)
+			if err != nil {
+				return fmt.Errorf("failed to create cache: %w", err)
+			}
+
+			size, err := cacheInstance.Size()
+			if err != nil {
+				return fmt.Errorf("failed to get cache size: %w", err)
+			}
+
+			fmt.Printf("Cache directory: %s\n", cacheInstance.Dir())
+			fmt.Printf("Number of entries: %d\n", cacheInstance.NumEntries())
+			fmt.Printf("Total size: %d bytes\n", size)
+
+			return nil
+		},
+	}
+	cacheStatsCmd.Flags().String("cache-dir", "", "Custom cache location")
+	return cacheStatsCmd
+}
+
+func NewCacheClearCmd() *cobra.Command {
+	cacheClearCmd := &cobra.Command{
+		Use:   "clear",
+		Short: "Clear the cache",
+		Long:  `Clear the cache.`,
+		RunE: func(cmd *cobra.Command, args []string) error {
+			cacheDir, err := GetCacheDir(cmd)
+			if err != nil {
+				return err
+			}
+			cacheInstance, err := cache.New(cacheDir)
+			if err != nil {
+				return fmt.Errorf("failed to create cache: %w", err)
+			}
+
+			err = cacheInstance.Clear()
+			if err != nil {
+				return fmt.Errorf("failed to clear cache: %w", err)
+			}
+
+			fmt.Println("Cache cleared.")
+
+			return nil
+		},
+	}
+	cacheClearCmd.Flags().String("cache-dir", "", "Custom cache location")
+	return cacheClearCmd
+}
diff --git a/cmd/collect_website.go b/cmd/collect_website.go
index 3811f32..d13855b 100644
--- a/cmd/collect_website.go
+++ b/cmd/collect_website.go
@@ -3,8 +3,9 @@ package cmd
 import (
 	"fmt"
 	"os"
 
 	"github.com/schollz/progressbar/v3"
+	"github.com/Snider/Borg/pkg/cache"
 	"github.com/Snider/Borg/pkg/compress"
 	"github.com/Snider/Borg/pkg/tim"
 	"github.com/Snider/Borg/pkg/trix"
@@ -51,11 +53,32 @@ func NewCollectWebsiteCmd() *cobra.Command {
 			bar = ui.NewProgressBar(-1, "Crawling website")
 		}
 
+		noCache, _ := cmd.Flags().GetBool("no-cache")
+
+		var cacheInstance *cache.Cache
+		var err error
+		if !noCache {
+			cacheDir, err := GetCacheDir(cmd)
+			if err != nil {
+				return err
+			}
+			cacheInstance, err = cache.New(cacheDir)
+			if err != nil {
+				return fmt.Errorf("failed to create cache: %w", err)
+			}
+		}
+
+		dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar, cacheInstance)
 		if err != nil {
 			return fmt.Errorf("error downloading and packaging website: %w", err)
 		}
+		if cacheInstance != nil {
+			if err := cacheInstance.Close(); err != nil {
+				return fmt.Errorf("failed to close cache: %w", err)
+			}
+		}
+
 
 		var data []byte
 		if format == "tim" {
 			tim, err := tim.FromDataNode(dn)
@@ -104,5 +127,7 @@ func NewCollectWebsiteCmd() *cobra.Command {
 	collectWebsiteCmd.PersistentFlags().String("format", "datanode", "Output format (datanode, tim, or trix)")
 	collectWebsiteCmd.PersistentFlags().String("compression", "none", "Compression format (none, gz, or xz)")
 	collectWebsiteCmd.PersistentFlags().String("password", "", "Password for encryption")
+	collectWebsiteCmd.Flags().Bool("no-cache", false, "Skip cache, re-download everything")
+	collectWebsiteCmd.Flags().String("cache-dir", "", "Custom cache location")
 	return collectWebsiteCmd
 }
diff --git a/cmd/collect_website_test.go b/cmd/collect_website_test.go
index 2c39674..6a97128 100644
--- a/cmd/collect_website_test.go
+++ b/cmd/collect_website_test.go
@@ -6,6 +6,7 @@ import (
 	"strings"
 	"testing"
 
+	"github.com/Snider/Borg/pkg/cache"
 	"github.com/Snider/Borg/pkg/datanode"
 	"github.com/Snider/Borg/pkg/website"
 	"github.com/schollz/progressbar/v3"
@@ -14,7 +15,7 @@ import (
 func TestCollectWebsiteCmd_Good(t *testing.T) {
 	// Mock the website downloader
 	oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
-	website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
+	website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar, cache *cache.Cache) (*datanode.DataNode, error) {
 		return datanode.New(), nil
 	}
 	defer func() {
@@ -35,7 +36,7 @@ func TestCollectWebsiteCmd_Good(t *testing.T) {
 func TestCollectWebsiteCmd_Bad(t *testing.T) {
 	// Mock the website downloader to return an error
 	oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
-	website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
+	website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar, cache *cache.Cache) (*datanode.DataNode, error) {
 		return nil, fmt.Errorf("website error")
 	}
 	defer func() {
diff --git a/cmd/root.go b/cmd/root.go
index 9cadb27..b18c3eb 100644
--- a/cmd/root.go
+++ b/cmd/root.go
@@ -2,7 +2,10 @@ package cmd
 
 import (
 	"context"
+	"fmt"
 	"log/slog"
+	"os"
+	"path/filepath"
 
 	"github.com/spf13/cobra"
 )
@@ -28,3 +31,16 @@ func Execute(log *slog.Logger) error {
 	RootCmd.SetContext(context.WithValue(context.Background(), "logger", log))
 	return RootCmd.Execute()
 }
+
+func GetCacheDir(cmd *cobra.Command) (string, error) {
+	cacheDir, _ := cmd.Flags().GetString("cache-dir")
+	if cacheDir != "" {
+		return cacheDir, nil
+	}
+
+	home, err := os.UserHomeDir()
+	if err != nil {
+		return "", fmt.Errorf("failed to get user home directory: %w", err)
+	}
+	return filepath.Join(home, ".borg", "cache"), nil
+}
diff --git a/examples/collect_website/main.go b/examples/collect_website/main.go
index 2e2f606..10cb4dc 100644
--- a/examples/collect_website/main.go
+++ b/examples/collect_website/main.go
@@ -11,7 +11,7 @@ func main() {
 	log.Println("Collecting website...")
 
 	// Download and package the website.
-	dn, err := website.DownloadAndPackageWebsite("https://example.com", 2, nil)
+	dn, err := website.DownloadAndPackageWebsite("https://example.com", 2, nil, nil)
 	if err != nil {
 		log.Fatalf("Failed to collect website: %v", err)
 	}
diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go
new file mode 100644
index 0000000..2d78c2f
--- /dev/null
+++ b/pkg/cache/cache.go
@@ -0,0 +1,158 @@
+package cache
+
+import (
+	"crypto/sha256"
+	"encoding/json"
+	"fmt"
+	"os"
+	"path/filepath"
+	"sync"
+)
+
+const (
+	indexFileName  = "index.json"
+	storageDirName = "sha256"
+)
+
+// Cache provides a content-addressable storage for web content.
+type Cache struct {
+	dir   string
+	index map[string]string
+	mutex sync.RWMutex
+}
+
+// New creates a new Cache instance.
+func New(dir string) (*Cache, error) {
+	storageDir := filepath.Join(dir, storageDirName)
+	if err := os.MkdirAll(storageDir, 0755); err != nil {
+		return nil, fmt.Errorf("failed to create cache storage directory: %w", err)
+	}
+
+	cache := &Cache{
+		dir:   dir,
+		index: make(map[string]string),
+	}
+
+	if err := cache.loadIndex(); err != nil {
+		return nil, fmt.Errorf("failed to load cache index: %w", err)
+	}
+
+	return cache, nil
+}
+
+// Get retrieves content from the cache for a given URL.
+func (c *Cache) Get(url string) ([]byte, bool, error) {
+	c.mutex.RLock()
+	hash, ok := c.index[url]
+	c.mutex.RUnlock()
+
+	if !ok {
+		return nil, false, nil
+	}
+
+	path := c.getStoragePath(hash)
+	data, err := os.ReadFile(path)
+	if err != nil {
+		if os.IsNotExist(err) {
+			return nil, false, nil
+		}
+		return nil, false, fmt.Errorf("failed to read from cache: %w", err)
+	}
+
+	return data, true, nil
+}
+
+// Put adds content to the cache for a given URL.
+func (c *Cache) Put(url string, data []byte) error {
+	hashBytes := sha256.Sum256(data)
+	hash := fmt.Sprintf("%x", hashBytes)
+
+	path := c.getStoragePath(hash)
+	if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
+		return fmt.Errorf("failed to create cache directory: %w", err)
+	}
+
+	if err := os.WriteFile(path, data, 0644); err != nil {
+		return fmt.Errorf("failed to write to cache: %w", err)
+	}
+
+	c.mutex.Lock()
+	c.index[url] = hash
+	c.mutex.Unlock()
+
+	return nil
+}
+
+// Close saves the index file.
+func (c *Cache) Close() error {
+	return c.saveIndex()
+}
+
+// Clear removes the cache directory.
+func (c *Cache) Clear() error {
+	return os.RemoveAll(c.dir)
+}
+
+// Dir returns the cache directory.
+func (c *Cache) Dir() string {
+	return c.dir
+}
+
+// Size returns the total size of the cache.
+func (c *Cache) Size() (int64, error) {
+	var size int64
+	err := filepath.Walk(c.dir, func(_ string, info os.FileInfo, err error) error {
+		if err != nil {
+			return err
+		}
+		if !info.IsDir() {
+			size += info.Size()
+		}
+		return nil
+	})
+	return size, err
+}
+
+// NumEntries returns the number of entries in the cache.
+func (c *Cache) NumEntries() int {
+	c.mutex.RLock()
+	defer c.mutex.RUnlock()
+	return len(c.index)
+}
+
+
+func (c *Cache) getStoragePath(hash string) string {
+	return filepath.Join(c.dir, storageDirName, hash[:2], hash)
+}
+
+func (c *Cache) loadIndex() error {
+	indexPath := filepath.Join(c.dir, indexFileName)
+	file, err := os.Open(indexPath)
+	if err != nil {
+		if os.IsNotExist(err) {
+			return nil
+		}
+		return err
+	}
+	defer file.Close()
+
+	decoder := json.NewDecoder(file)
+	c.mutex.Lock()
+	defer c.mutex.Unlock()
+	return decoder.Decode(&c.index)
+}
+
+func (c *Cache) saveIndex() error {
+	indexPath := filepath.Join(c.dir, indexFileName)
+	file, err := os.Create(indexPath)
+	if err != nil {
+		return err
+	}
+	defer file.Close()
+
+	encoder := json.NewEncoder(file)
+	encoder.SetIndent("", "  ")
+	c.mutex.Lock()
+	defer c.mutex.Unlock()
+	return encoder.Encode(c.index)
+}
diff --git a/pkg/cache/cache_test.go b/pkg/cache/cache_test.go
new file mode 100644
index 0000000..4fa766a
--- /dev/null
+++ b/pkg/cache/cache_test.go
@@ -0,0 +1,86 @@
+package cache
+
+import (
+	"os"
+	"testing"
+)
+
+func TestCache_Good(t *testing.T) {
+	cacheDir := t.TempDir()
+	c, err := New(cacheDir)
+	if err != nil {
+		t.Fatalf("Failed to create cache: %v", err)
+	}
+
+	// Test Put and Get
+	url := "https://example.com"
+	data := []byte("hello world")
+	err = c.Put(url, data)
+	if err != nil {
+		t.Fatalf("Failed to put data in cache: %v", err)
+	}
+
+	cachedData, ok, err := c.Get(url)
+	if err != nil {
+		t.Fatalf("Failed to get data from cache: %v", err)
+	}
+	if !ok {
+		t.Fatal("Expected data to be in cache, but it wasn't")
+	}
+	if string(cachedData) != string(data) {
+		t.Errorf("Expected data to be '%s', but got '%s'", string(data), string(cachedData))
+	}
+
+	// Test persistence
+	err = c.Close()
+	if err != nil {
+		t.Fatalf("Failed to close cache: %v", err)
+	}
+
+	c2, err := New(cacheDir)
+	if err != nil {
+		t.Fatalf("Failed to create new cache: %v", err)
+	}
+	cachedData, ok, err = c2.Get(url)
+	if err != nil {
+		t.Fatalf("Failed to get data from new cache: %v", err)
+	}
+	if !ok {
+		t.Fatal("Expected data to be in new cache, but it wasn't")
+	}
+	if string(cachedData) != string(data) {
+		t.Errorf("Expected data to be '%s', but got '%s'", string(data), string(cachedData))
+	}
+}
+
+func TestCache_Clear(t *testing.T) {
+	cacheDir := t.TempDir()
+	c, err := New(cacheDir)
+	if err != nil {
+		t.Fatalf("Failed to create cache: %v", err)
+	}
+
+	// Put some data in the cache
+	url := "https://example.com"
+	data := []byte("hello world")
+	err = c.Put(url, data)
+	if err != nil {
+		t.Fatalf("Failed to put data in cache: %v", err)
+	}
+	err = c.Close()
+	if err != nil {
+		t.Fatalf("Failed to close cache: %v", err)
+	}
+
+	// Clear the cache
+	err = c.Clear()
+	if err != nil {
+		t.Fatalf("Failed to clear cache: %v", err)
+	}
+
+	// Check that the cache directory is gone
+	_, err = os.Stat(cacheDir)
+	if !os.IsNotExist(err) {
+		t.Fatal("Expected cache directory to be gone, but it wasn't")
+	}
+}
diff --git a/pkg/website/website.go b/pkg/website/website.go
index b2bd517..3b73827 100644
--- a/pkg/website/website.go
+++ b/pkg/website/website.go
@@ -7,6 +7,7 @@ import (
 	"net/url"
 	"strings"
 
+	"github.com/Snider/Borg/pkg/cache"
 	"github.com/Snider/Borg/pkg/datanode"
 	"github.com/schollz/progressbar/v3"
 
@@ -24,32 +25,34 @@ type Downloader struct {
 	progressBar *progressbar.ProgressBar
 	client      *http.Client
 	errors      []error
+	cache       *cache.Cache
 }
 
 // NewDownloader creates a new Downloader.
-func NewDownloader(maxDepth int) *Downloader {
-	return NewDownloaderWithClient(maxDepth, http.DefaultClient)
+func NewDownloader(maxDepth int, cache *cache.Cache) *Downloader {
+	return NewDownloaderWithClient(maxDepth, http.DefaultClient, cache)
 }
 
 // NewDownloaderWithClient creates a new Downloader with a custom http.Client.
-func NewDownloaderWithClient(maxDepth int, client *http.Client) *Downloader {
+func NewDownloaderWithClient(maxDepth int, client *http.Client, cache *cache.Cache) *Downloader {
 	return &Downloader{
 		dn:       datanode.New(),
 		visited:  make(map[string]bool),
 		maxDepth: maxDepth,
 		client:   client,
 		errors:   make([]error, 0),
+		cache:    cache,
 	}
 }
 
 // downloadAndPackageWebsite downloads a website and packages it into a DataNode.
-func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
+func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.ProgressBar, cache *cache.Cache) (*datanode.DataNode, error) {
 	baseURL, err := url.Parse(startURL)
 	if err != nil {
 		return nil, err
 	}
-	d := NewDownloader(maxDepth)
+	d := NewDownloader(maxDepth, cache)
 	d.baseURL = baseURL
 	d.progressBar = bar
 	d.crawl(startURL, 0)
 
@@ -69,34 +72,14 @@ func (d *Downloader) crawl(pageURL string, depth int) {
 	if depth > d.maxDepth || d.visited[pageURL] {
 		return
 	}
-	d.visited[pageURL] = true
-	if d.progressBar != nil {
-		d.progressBar.Add(1)
-	}
 
-	resp, err := d.client.Get(pageURL)
-	if err != nil {
-		d.errors = append(d.errors, fmt.Errorf("Error getting %s: %w", pageURL, err))
+	body, contentType := d.download(pageURL)
+	if body == nil {
 		return
 	}
-	defer resp.Body.Close()
-
-	if resp.StatusCode >= 400 {
-		d.errors = append(d.errors, fmt.Errorf("bad status for %s: %s", pageURL, resp.Status))
-		return
-	}
-
-	body, err := io.ReadAll(resp.Body)
-	if err != nil {
-		d.errors = append(d.errors, fmt.Errorf("Error reading body of %s: %w", pageURL, err))
-		return
-	}
-
-	relPath := d.getRelativePath(pageURL)
-	d.dn.AddData(relPath, body)
 
 	// Don't try to parse non-html content
-	if !strings.HasPrefix(resp.Header.Get("Content-Type"), "text/html") {
+	if !strings.HasPrefix(contentType, "text/html") {
 		return
 	}
 
@@ -136,31 +119,56 @@ func (d *Downloader) downloadAsset(assetURL string) {
 	if d.visited[assetURL] {
 		return
 	}
-	d.visited[assetURL] = true
+	d.download(assetURL)
+}
+
+func (d *Downloader) download(pageURL string) ([]byte, string) {
+	d.visited[pageURL] = true
 	if d.progressBar != nil {
 		d.progressBar.Add(1)
 	}
 
-	resp, err := d.client.Get(assetURL)
+	// Check the cache first
+	if d.cache != nil {
+		data, ok, err := d.cache.Get(pageURL)
+		if err != nil {
+			d.errors = append(d.errors, fmt.Errorf("Error getting from cache %s: %w", pageURL, err))
+			// Don't return, as we can still try to download it
+		}
+		if ok {
+			relPath := d.getRelativePath(pageURL)
+			d.dn.AddData(relPath, data)
+			return data, "" // We don't know the content type from the cache
+		}
+	}
+
+	resp, err := d.client.Get(pageURL)
 	if err != nil {
-		d.errors = append(d.errors, fmt.Errorf("Error getting asset %s: %w", assetURL, err))
-		return
+		d.errors = append(d.errors, fmt.Errorf("Error getting %s: %w", pageURL, err))
+		return nil, ""
 	}
 	defer resp.Body.Close()
 
 	if resp.StatusCode >= 400 {
-		d.errors = append(d.errors, fmt.Errorf("bad status for asset %s: %s", assetURL, resp.Status))
-		return
+		d.errors = append(d.errors, fmt.Errorf("bad status for %s: %s", pageURL, resp.Status))
+		return nil, ""
 	}
 
 	body, err := io.ReadAll(resp.Body)
 	if err != nil {
-		d.errors = append(d.errors, fmt.Errorf("Error reading body of asset %s: %w", assetURL, err))
-		return
+		d.errors = append(d.errors, fmt.Errorf("Error reading body of %s: %w", pageURL, err))
+		return nil, ""
 	}
 
-	relPath := d.getRelativePath(assetURL)
+	relPath := d.getRelativePath(pageURL)
 	d.dn.AddData(relPath, body)
+
+	// Add to cache
+	if d.cache != nil {
+		d.cache.Put(pageURL, body)
+	}
+
+	return body, resp.Header.Get("Content-Type")
 }
 
 func (d *Downloader) getRelativePath(pageURL string) string {
diff --git a/pkg/website/website_test.go b/pkg/website/website_test.go
index d3685e5..6b0a0f2 100644
--- a/pkg/website/website_test.go
+++ b/pkg/website/website_test.go
@@ -7,9 +7,11 @@ import (
 	"net/http"
 	"net/http/httptest"
 	"strings"
+	"sync/atomic"
 	"testing"
 	"time"
 
+	"github.com/Snider/Borg/pkg/cache"
 	"github.com/schollz/progressbar/v3"
 )
 
@@ -20,7 +22,7 @@ func TestDownloadAndPackageWebsite_Good(t *testing.T) {
 	defer server.Close()
 
 	bar := progressbar.NewOptions(1, progressbar.OptionSetWriter(io.Discard))
-	dn, err := DownloadAndPackageWebsite(server.URL, 2, bar)
+	dn, err := DownloadAndPackageWebsite(server.URL, 2, bar, nil)
 	if err != nil {
 		t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
 	}
@@ -52,7 +54,7 @@ func TestDownloadAndPackageWebsite_Good(t *testing.T) {
 
 func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
 	t.Run("Invalid Start URL", func(t *testing.T) {
-		_, err := DownloadAndPackageWebsite("http://invalid-url", 1, nil)
+		_, err := DownloadAndPackageWebsite("http://invalid-url", 1, nil, nil)
 		if err == nil {
 			t.Fatal("Expected an error for an invalid start URL, but got nil")
 		}
@@ -63,7 +65,7 @@ func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
 			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
 		}))
 		defer server.Close()
-		_, err := DownloadAndPackageWebsite(server.URL, 1, nil)
+		_, err := DownloadAndPackageWebsite(server.URL, 1, nil, nil)
 		if err == nil {
 			t.Fatal("Expected an error for a server error on the start URL, but got nil")
 		}
@@ -80,7 +82,7 @@ func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
 		}))
 		defer server.Close()
 		// We expect an error because the link is broken.
-		dn, err := DownloadAndPackageWebsite(server.URL, 1, nil)
+		dn, err := DownloadAndPackageWebsite(server.URL, 1, nil, nil)
 		if err == nil {
 			t.Fatal("Expected an error for a broken link, but got nil")
 		}
@@ -99,7 +101,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
 	defer server.Close()
 
 	bar := progressbar.NewOptions(1, progressbar.OptionSetWriter(io.Discard))
-	dn, err := DownloadAndPackageWebsite(server.URL, 1, bar) // Max depth of 1
+	dn, err := DownloadAndPackageWebsite(server.URL, 1, bar, nil) // Max depth of 1
 	if err != nil {
 		t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
 	}
@@ -122,7 +124,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
 			fmt.Fprint(w, `External`)
 		}))
 		defer server.Close()
-		dn, err := DownloadAndPackageWebsite(server.URL, 1, nil)
+		dn, err := DownloadAndPackageWebsite(server.URL, 1, nil, nil)
 		if err != nil {
 			t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
 		}
@@ -156,7 +158,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
 		// For now, we'll just test that it doesn't hang forever.
 		done := make(chan bool)
 		go func() {
-			_, err := DownloadAndPackageWebsite(server.URL, 1, nil)
+			_, err := DownloadAndPackageWebsite(server.URL, 1, nil, nil)
 			if err != nil && !strings.Contains(err.Error(), "context deadline exceeded") {
 				// We expect a timeout error, but other errors are failures.
 				t.Errorf("unexpected error: %v", err)
@@ -174,6 +176,60 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
 
 // --- Helpers ---
 
+func TestDownloadAndPackageWebsite_Cache(t *testing.T) {
+	var requestCount int32
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		atomic.AddInt32(&requestCount, 1)
+		switch r.URL.Path {
+		case "/":
+			w.Header().Set("Content-Type", "text/html")
+			fmt.Fprint(w, `<a href="/page2.html">Page 2</a>`)
+		case "/page2.html":
+			w.Header().Set("Content-Type", "text/html")
+			fmt.Fprint(w, `Page 2`)
+		default:
+			http.NotFound(w, r)
+		}
+	}))
+	defer server.Close()
+
+	cacheDir := t.TempDir()
+	c, err := cache.New(cacheDir)
+	if err != nil {
+		t.Fatalf("Failed to create cache: %v", err)
+	}
+
+	// First download
+	_, err = DownloadAndPackageWebsite(server.URL, 2, nil, c)
+	if err != nil {
+		t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
+	}
+	if err := c.Close(); err != nil {
+		t.Fatalf("Failed to close cache: %v", err)
+	}
+
+	if atomic.LoadInt32(&requestCount) != 2 {
+		t.Errorf("Expected 2 requests to the server, but got %d", requestCount)
+	}
+
+	// Second download
+	c2, err := cache.New(cacheDir)
+	if err != nil {
+		t.Fatalf("Failed to create new cache: %v", err)
+	}
+	_, err = DownloadAndPackageWebsite(server.URL, 2, nil, c2)
+	if err != nil {
+		t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
+	}
+	if err := c2.Close(); err != nil {
+		t.Fatalf("Failed to close cache: %v", err)
+	}
+
+	if atomic.LoadInt32(&requestCount) != 2 {
+		t.Errorf("Expected 2 requests to the server after caching, but got %d", requestCount)
+	}
+}
+
 func newWebsiteTestServer() *httptest.Server {
 	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		switch r.URL.Path {