diff --git a/cmd/collect_website.go b/cmd/collect_website.go
index 3811f32..4016bca 100644
--- a/cmd/collect_website.go
+++ b/cmd/collect_website.go
@@ -38,6 +38,9 @@ func NewCollectWebsiteCmd() *cobra.Command {
 			format, _ := cmd.Flags().GetString("format")
 			compression, _ := cmd.Flags().GetString("compression")
 			password, _ := cmd.Flags().GetString("password")
+			useSitemap, _ := cmd.Flags().GetBool("use-sitemap")
+			sitemapOnly, _ := cmd.Flags().GetBool("sitemap-only")
+			sitemapURL, _ := cmd.Flags().GetString("sitemap")
 
 			if format != "datanode" && format != "tim" && format != "trix" {
 				return fmt.Errorf("invalid format: %s (must be 'datanode', 'tim', or 'trix')", format)
@@ -51,7 +54,7 @@ func NewCollectWebsiteCmd() *cobra.Command {
 				bar = ui.NewProgressBar(-1, "Crawling website")
 			}
 
-			dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar)
+			dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, useSitemap, sitemapOnly, sitemapURL, bar)
 			if err != nil {
 				return fmt.Errorf("error downloading and packaging website: %w", err)
 			}
@@ -104,5 +107,8 @@ func NewCollectWebsiteCmd() *cobra.Command {
 	collectWebsiteCmd.PersistentFlags().String("format", "datanode", "Output format (datanode, tim, or trix)")
 	collectWebsiteCmd.PersistentFlags().String("compression", "none", "Compression format (none, gz, or xz)")
 	collectWebsiteCmd.PersistentFlags().String("password", "", "Password for encryption")
+	collectWebsiteCmd.Flags().Bool("use-sitemap", false, "Auto-detect and use sitemap")
+	collectWebsiteCmd.Flags().Bool("sitemap-only", false, "Collect only sitemap URLs (no crawling)")
+	collectWebsiteCmd.Flags().String("sitemap", "", "Explicit sitemap URL")
 	return collectWebsiteCmd
 }
diff --git a/cmd/collect_website_test.go b/cmd/collect_website_test.go
index 2c39674..a9dc415 100644
--- a/cmd/collect_website_test.go
+++ b/cmd/collect_website_test.go
@@ -14,7 +14,7 @@ import (
 func TestCollectWebsiteCmd_Good(t *testing.T) {
 	// Mock the website downloader
 	oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
-	website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
+	website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, useSitemap, sitemapOnly bool, sitemapURL string, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
 		return datanode.New(), nil
 	}
 	defer func() {
@@ -32,10 +32,76 @@ func TestCollectWebsiteCmd_Good(t *testing.T) {
 	}
 }
 
+func TestCollectWebsiteCmd_Sitemap_Good(t *testing.T) {
+	var capturedUseSitemap, capturedSitemapOnly bool
+	var capturedSitemapURL string
+
+	oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
+	website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, useSitemap, sitemapOnly bool, sitemapURL string, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
+		capturedUseSitemap = useSitemap
+		capturedSitemapOnly = sitemapOnly
+		capturedSitemapURL = sitemapURL
+		return datanode.New(), nil
+	}
+	defer func() {
+		website.DownloadAndPackageWebsite = oldDownloadAndPackageWebsite
+	}()
+
+	testCases := []struct {
+		name                string
+		args                []string
+		expectedUseSitemap  bool
+		expectedSitemapOnly bool
+		expectedSitemapURL  string
+	}{
+		{
+			name:                "use-sitemap flag",
+			args:                []string{"https://example.com", "--use-sitemap"},
+			expectedUseSitemap:  true,
+			expectedSitemapOnly: false,
+			expectedSitemapURL:  "",
+		},
+		{
+			name:                "sitemap-only flag",
+			args:                []string{"https://example.com", "--sitemap-only"},
+			expectedUseSitemap:  false,
+			expectedSitemapOnly: true,
+			expectedSitemapURL:  "",
+		},
+		{
+			name:                "sitemap flag",
+			args:                []string{"https://example.com", "--sitemap", "https://example.com/sitemap.xml"},
+			expectedUseSitemap:  false,
+			expectedSitemapOnly: false,
+			expectedSitemapURL:  "https://example.com/sitemap.xml",
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			rootCmd := NewCollectWebsiteCmd()
+			_, err := executeCommand(rootCmd, tc.args...)
+			if err != nil {
+				t.Fatalf("command execution failed: %v", err)
+			}
+
+			if capturedUseSitemap != tc.expectedUseSitemap {
+				t.Errorf("expected useSitemap to be %v, but got %v", tc.expectedUseSitemap, capturedUseSitemap)
+			}
+			if capturedSitemapOnly != tc.expectedSitemapOnly {
+				t.Errorf("expected sitemapOnly to be %v, but got %v", tc.expectedSitemapOnly, capturedSitemapOnly)
+			}
+			if capturedSitemapURL != tc.expectedSitemapURL {
+				t.Errorf("expected sitemapURL to be %q, but got %q", tc.expectedSitemapURL, capturedSitemapURL)
+			}
+		})
+	}
+}
+
 func TestCollectWebsiteCmd_Bad(t *testing.T) {
 	// Mock the website downloader to return an error
 	oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
-	website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
+	website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, useSitemap, sitemapOnly bool, sitemapURL string, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
 		return nil, fmt.Errorf("website error")
 	}
 	defer func() {
diff --git a/examples/collect_website/main.go b/examples/collect_website/main.go
index 2e2f606..bdd96b9 100644
--- a/examples/collect_website/main.go
+++ b/examples/collect_website/main.go
@@ -11,7 +11,7 @@ func main() {
 	log.Println("Collecting website...")
 
 	// Download and package the website.
-	dn, err := website.DownloadAndPackageWebsite("https://example.com", 2, nil)
+	dn, err := website.DownloadAndPackageWebsite("https://example.com", 2, false, false, "", nil)
 	if err != nil {
 		log.Fatalf("Failed to collect website: %v", err)
 	}
diff --git a/pkg/website/sitemap.go b/pkg/website/sitemap.go
new file mode 100644
index 0000000..d59efc8
--- /dev/null
+++ b/pkg/website/sitemap.go
@@ -0,0 +1,116 @@
+package website
+
+import (
+	"compress/gzip"
+	"encoding/xml"
+	"fmt"
+	"io"
+	"net/http"
+	"net/url"
+	"strings"
+)
+
+// SitemapURL represents a single URL entry in a sitemap.
+type SitemapURL struct {
+	Loc string `xml:"loc"`
+}
+
+// URLSet represents a standard sitemap.
+type URLSet struct {
+	XMLName xml.Name     `xml:"urlset"`
+	URLs    []SitemapURL `xml:"url"`
+}
+
+// SitemapIndex represents a sitemap index file.
+type SitemapIndex struct {
+	XMLName  xml.Name     `xml:"sitemapindex"`
+	Sitemaps []SitemapURL `xml:"sitemap"`
+}
+
+// FetchAndParseSitemap fetches and parses a sitemap from the given URL.
+// It handles standard sitemaps, sitemap indexes, and gzipped sitemaps.
+func FetchAndParseSitemap(sitemapURL string, client *http.Client) ([]string, error) {
+	resp, err := client.Get(sitemapURL)
+	if err != nil {
+		return nil, fmt.Errorf("error fetching sitemap: %w", err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("bad status for sitemap: %s", resp.Status)
+	}
+
+	var reader io.Reader = resp.Body
+	if strings.HasSuffix(sitemapURL, ".gz") {
+		gzReader, err := gzip.NewReader(resp.Body)
+		if err != nil {
+			return nil, fmt.Errorf("error creating gzip reader: %w", err)
+		}
+		defer gzReader.Close()
+		reader = gzReader
+	}
+
+	body, err := io.ReadAll(reader)
+	if err != nil {
+		return nil, fmt.Errorf("error reading sitemap body: %w", err)
+	}
+
+	// Try parsing as a sitemap index first
+	var sitemapIndex SitemapIndex
+	if err := xml.Unmarshal(body, &sitemapIndex); err == nil && len(sitemapIndex.Sitemaps) > 0 {
+		var allURLs []string
+		for _, sitemap := range sitemapIndex.Sitemaps {
+			urls, err := FetchAndParseSitemap(sitemap.Loc, client)
+			if err != nil {
+				// In a real-world scenario, you might want to handle this more gracefully
+				return nil, fmt.Errorf("error parsing sitemap from index '%s': %w", sitemap.Loc, err)
+			}
+			allURLs = append(allURLs, urls...)
+		}
+		return allURLs, nil
+	}
+
+	// If not a sitemap index, try parsing as a standard sitemap
+	var urlSet URLSet
+	if err := xml.Unmarshal(body, &urlSet); err == nil {
+		var urls []string
+		for _, u := range urlSet.URLs {
+			urls = append(urls, u.Loc)
+		}
+		return urls, nil
+	}
+
+	return nil, fmt.Errorf("failed to parse sitemap XML from %s", sitemapURL)
+}
+
+// DiscoverSitemap attempts to discover the sitemap URL for a given base URL.
+func DiscoverSitemap(baseURL string, client *http.Client) (string, error) {
+	u, err := url.Parse(baseURL)
+	if err != nil {
+		return "", err
+	}
+	// Make sure we're at the root of the domain
+	u.Path = ""
+	u.RawQuery = ""
+	u.Fragment = ""
+
+	sitemapPaths := []string{"sitemap.xml", "sitemap_index.xml", "sitemap.xml.gz", "sitemap_index.xml.gz"}
+
+	for _, path := range sitemapPaths {
+		sitemapURL := u.String() + "/" + path
+		resp, err := client.Head(sitemapURL)
+		if err == nil && resp.StatusCode == http.StatusOK {
+			// Ensure we close the body, even for a HEAD request
+			io.Copy(io.Discard, resp.Body)
+			resp.Body.Close()
+			return sitemapURL, nil
+		}
+		if resp != nil {
+			io.Copy(io.Discard, resp.Body)
+			resp.Body.Close()
+		}
+	}
+
+	return "", fmt.Errorf("sitemap not found for %s", baseURL)
+}
diff --git a/pkg/website/sitemap_test.go b/pkg/website/sitemap_test.go
new file mode 100644
index 0000000..a9f5145
--- /dev/null
+++ b/pkg/website/sitemap_test.go
@@ -0,0 +1,118 @@
+package website
+
+import (
+	"compress/gzip"
+	"fmt"
+	"net/http"
+	"net/http/httptest"
+	"reflect"
+	"testing"
+)
+
+func TestFetchAndParseSitemap_Good(t *testing.T) {
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		switch r.URL.Path {
+		case "/sitemap.xml":
+			fmt.Fprintln(w, `<?xml version="1.0" encoding="UTF-8"?>
+<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">
+	<url>
+		<loc>http://www.example.com/</loc>
+	</url>
+	<url>
+		<loc>http://www.example.com/page1</loc>
+	</url>
+</urlset>`)
+		case "/sitemap_index.xml":
+			fmt.Fprintln(w, `<?xml version="1.0" encoding="UTF-8"?>
+<sitemapindex xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">
+	<sitemap>
+		<loc>http://`+r.Host+`/sitemap1.xml</loc>
+	</sitemap>
+	<sitemap>
+		<loc>http://`+r.Host+`/sitemap2.xml</loc>
+	</sitemap>
+</sitemapindex>`)
+		case "/sitemap1.xml":
+			fmt.Fprintln(w, `<?xml version="1.0" encoding="UTF-8"?>
+<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">
+	<url>
+		<loc>http://www.example.com/page2</loc>
+	</url>
+</urlset>`)
+		case "/sitemap2.xml":
+			fmt.Fprintln(w, `<?xml version="1.0" encoding="UTF-8"?>
+<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">
+	<url>
+		<loc>http://www.example.com/page3</loc>
+	</url>
+</urlset>`)
+		case "/sitemap.xml.gz":
+			w.Header().Set("Content-Type", "application/gzip")
+			gz := gzip.NewWriter(w)
+			defer gz.Close()
+			gz.Write([]byte(`<?xml version="1.0" encoding="UTF-8"?>
+<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">
+	<url>
+		<loc>http://www.example.com/gzipped</loc>
+	</url>
+</urlset>`))
+		default:
+			http.NotFound(w, r)
+		}
+	}))
+	defer server.Close()
+
+	testCases := []struct {
+		name     string
+		url      string
+		expected []string
+	}{
+		{"Standard Sitemap", server.URL + "/sitemap.xml", []string{"http://www.example.com/", "http://www.example.com/page1"}},
+		{"Sitemap Index", server.URL + "/sitemap_index.xml", []string{"http://www.example.com/page2", "http://www.example.com/page3"}},
+		{"Gzipped Sitemap", server.URL + "/sitemap.xml.gz", []string{"http://www.example.com/gzipped"}},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			urls, err := FetchAndParseSitemap(tc.url, server.Client())
+			if err != nil {
+				t.Fatalf("FetchAndParseSitemap() error = %v", err)
+			}
+			if !reflect.DeepEqual(urls, tc.expected) {
+				t.Errorf("FetchAndParseSitemap() = %v, want %v", urls, tc.expected)
+			}
+		})
+	}
+}
+
+func TestDiscoverSitemap_Good(t *testing.T) {
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if r.URL.Path == "/sitemap.xml" {
+			w.WriteHeader(http.StatusOK)
+		} else {
+			http.NotFound(w, r)
+		}
+	}))
+	defer server.Close()
+
+	sitemapURL, err := DiscoverSitemap(server.URL, server.Client())
+	if err != nil {
+		t.Fatalf("DiscoverSitemap() error = %v", err)
+	}
+	expected := server.URL + "/sitemap.xml"
+	if sitemapURL != expected {
+		t.Errorf("DiscoverSitemap() = %v, want %v", sitemapURL, expected)
+	}
+}
+
+func TestDiscoverSitemap_Bad(t *testing.T) {
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		http.NotFound(w, r)
+	}))
+	defer server.Close()
+
+	_, err := DiscoverSitemap(server.URL, server.Client())
+	if err == nil {
+		t.Error("DiscoverSitemap() error = nil, want error")
+	}
+}
diff --git a/pkg/website/website.go b/pkg/website/website.go
index b2bd517..5b5002a 100644
--- a/pkg/website/website.go
+++ b/pkg/website/website.go
@@ -5,6 +5,7 @@ import (
 	"io"
 	"net/http"
 	"net/url"
+	"os"
 	"strings"
 
 	"github.com/Snider/Borg/pkg/datanode"
@@ -15,6 +16,66 @@ import (
 
 var DownloadAndPackageWebsite = downloadAndPackageWebsite
 
+func downloadAndPackageWebsite(startURL string, maxDepth int, useSitemap, sitemapOnly bool, sitemapURL string, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
+	baseURL, err := url.Parse(startURL)
+	if err != nil {
+		return nil, err
+	}
+
+	d := NewDownloader(maxDepth)
+	d.baseURL = baseURL
+	d.progressBar = bar
+	d.sitemapOnly = sitemapOnly
+
+	var sitemapURLs []string
+	var sitemapErr error
+
+	// Determine which sitemap URL to use
+	actualSitemapURL := sitemapURL
+	if actualSitemapURL == "" && (useSitemap || sitemapOnly) {
+		actualSitemapURL, sitemapErr = DiscoverSitemap(startURL, d.client)
+		if sitemapErr != nil {
+			if sitemapOnly {
+				return nil, fmt.Errorf("sitemap discovery failed for %s: %w", startURL, sitemapErr)
+			}
+			// For useSitemap, we can warn and proceed with crawling
+			fmt.Fprintf(os.Stderr, "Warning: sitemap discovery failed for %s: %v. Proceeding with crawl.\n", startURL, sitemapErr)
+		}
+	}
+
+	if actualSitemapURL != "" {
+		sitemapURLs, sitemapErr = FetchAndParseSitemap(actualSitemapURL, d.client)
+		if sitemapErr != nil {
+			if sitemapOnly {
+				return nil, fmt.Errorf("sitemap parsing failed for %s: %w", actualSitemapURL, sitemapErr)
+			}
+			fmt.Fprintf(os.Stderr, "Warning: sitemap parsing failed for %s: %v. Proceeding with crawl.\n", actualSitemapURL, sitemapErr)
+		}
+	}
+
+	// Process URLs from sitemap
+	if len(sitemapURLs) > 0 {
+		for _, u := range sitemapURLs {
+			d.crawl(u, 0)
+		}
+	}
+
+	// Crawl from start URL if not in sitemap-only mode
+	if !sitemapOnly {
+		d.crawl(startURL, 0)
+	}
+
+	if len(d.errors) > 0 {
+		var errs []string
+		for _, e := range d.errors {
+			errs = append(errs, e.Error())
+		}
+		return nil, fmt.Errorf("failed to download website:\n%s", strings.Join(errs, "\n"))
+	}
+
+	return d.dn, nil
+}
+
 // Downloader is a recursive website downloader.
 type Downloader struct {
 	baseURL *url.URL
@@ -22,6 +83,7 @@ type Downloader struct {
 	visited     map[string]bool
 	maxDepth    int
 	progressBar *progressbar.ProgressBar
+	sitemapOnly bool
 	client      *http.Client
 	errors      []error
 }
@@ -37,33 +99,12 @@ func NewDownloaderWithClient(maxDepth int, client *http.Client) *Downloader {
 		dn:          datanode.New(),
 		visited:     make(map[string]bool),
 		maxDepth:    maxDepth,
+		sitemapOnly: false,
 		client:      client,
 		errors:      make([]error, 0),
 	}
 }
 
-// downloadAndPackageWebsite downloads a website and packages it into a DataNode.
-func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
-	baseURL, err := url.Parse(startURL)
-	if err != nil {
-		return nil, err
-	}
-
-	d := NewDownloader(maxDepth)
-	d.baseURL = baseURL
-	d.progressBar = bar
-	d.crawl(startURL, 0)
-
-	if len(d.errors) > 0 {
-		var errs []string
-		for _, e := range d.errors {
-			errs = append(errs, e.Error())
-		}
-		return nil, fmt.Errorf("failed to download website:\n%s", strings.Join(errs, "\n"))
-	}
-
-	return d.dn, nil
-}
-
 func (d *Downloader) crawl(pageURL string, depth int) {
 	if depth > d.maxDepth || d.visited[pageURL] {
@@ -118,7 +159,7 @@ func (d *Downloader) crawl(pageURL string, depth int) {
 		if d.isLocal(link) {
 			if isAsset(link) {
 				d.downloadAsset(link)
-			} else {
+			} else if !d.sitemapOnly {
 				d.crawl(link, depth+1)
 			}
 		}
diff --git a/pkg/website/website_test.go b/pkg/website/website_test.go
index d3685e5..1c433c7 100644
--- a/pkg/website/website_test.go
+++ b/pkg/website/website_test.go
@@ -20,7 +20,7 @@ func TestDownloadAndPackageWebsite_Good(t *testing.T) {
 	defer server.Close()
 
 	bar := progressbar.NewOptions(1, progressbar.OptionSetWriter(io.Discard))
-	dn, err := DownloadAndPackageWebsite(server.URL, 2, bar)
+	dn, err := DownloadAndPackageWebsite(server.URL, 2, false, false, "", bar)
 	if err != nil {
 		t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
 	}
@@ -52,7 +52,7 @@ func TestDownloadAndPackageWebsite_Good(t *testing.T) {
 
 func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
 	t.Run("Invalid Start URL", func(t *testing.T) {
-		_, err := DownloadAndPackageWebsite("http://invalid-url", 1, nil)
+		_, err := DownloadAndPackageWebsite("http://invalid-url", 1, false, false, "", nil)
 		if err == nil {
 			t.Fatal("Expected an error for an invalid start URL, but got nil")
 		}
@@ -63,7 +63,7 @@ func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
 			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
 		}))
 		defer server.Close()
-		_, err := DownloadAndPackageWebsite(server.URL, 1, nil)
+		_, err := DownloadAndPackageWebsite(server.URL, 1, false, false, "", nil)
 		if err == nil {
 			t.Fatal("Expected an error for a server error on the start URL, but got nil")
 		}
@@ -80,7 +80,7 @@ func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
 		}))
 		defer server.Close()
 		// We expect an error because the link is broken.
-		dn, err := DownloadAndPackageWebsite(server.URL, 1, nil)
+		dn, err := DownloadAndPackageWebsite(server.URL, 1, false, false, "", nil)
 		if err == nil {
 			t.Fatal("Expected an error for a broken link, but got nil")
 		}
@@ -99,7 +99,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
 	defer server.Close()
 
 	bar := progressbar.NewOptions(1, progressbar.OptionSetWriter(io.Discard))
-	dn, err := DownloadAndPackageWebsite(server.URL, 1, bar) // Max depth of 1
+	dn, err := DownloadAndPackageWebsite(server.URL, 1, false, false, "", bar) // Max depth of 1
 	if err != nil {
 		t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
 	}
@@ -122,7 +122,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
 		fmt.Fprint(w, `External`)
 	}))
 	defer server.Close()
-	dn, err := DownloadAndPackageWebsite(server.URL, 1, nil)
+	dn, err := DownloadAndPackageWebsite(server.URL, 1, false, false, "", nil)
 	if err != nil {
 		t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
 	}
@@ -156,7 +156,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
 	// For now, we'll just test that it doesn't hang forever.
 	done := make(chan bool)
 	go func() {
-		_, err := DownloadAndPackageWebsite(server.URL, 1, nil)
+		_, err := DownloadAndPackageWebsite(server.URL, 1, false, false, "", nil)
 		if err != nil && !strings.Contains(err.Error(), "context deadline exceeded") {
 			// We expect a timeout error, but other errors are failures.
 			t.Errorf("unexpected error: %v", err)
@@ -172,6 +172,63 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
 	})
 }
 
+func TestDownloadAndPackageWebsite_Sitemap(t *testing.T) {
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		switch r.URL.Path {
+		case "/sitemap.xml":
+			fmt.Fprintln(w, `<?xml version="1.0" encoding="UTF-8"?>
+<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">
+	<url>
+		<loc>http://`+r.Host+`/page1</loc>
+	</url>
+</urlset>`)
+		case "/":
+			w.Header().Set("Content-Type", "text/html")
+			fmt.Fprint(w, `<html>
+<body>
+<a href="/page2">Page 2</a>
+</body>
+</html>`)
+		case "/page1":
+			w.Header().Set("Content-Type", "text/html")
+			fmt.Fprint(w, `<html>
+<body>
+Page 1 from Sitemap
+</body>
+</html>`)
+		case "/page2":
+			w.Header().Set("Content-Type", "text/html")
+			fmt.Fprint(w, `<html>
+<body>
+Page 2 from Crawl
+</body>
+</html>`)
+		default:
+			http.NotFound(w, r)
+		}
+	}))
+	defer server.Close()
+
+	t.Run("sitemap-only", func(t *testing.T) {
+		dn, err := DownloadAndPackageWebsite(server.URL, 2, false, true, server.URL+"/sitemap.xml", nil)
+		if err != nil {
+			t.Fatalf("DownloadAndPackageWebsite with sitemap-only failed: %v", err)
+		}
+		if exists, _ := dn.Exists("page1"); !exists {
+			t.Error("Expected to find /page1 from sitemap, but it was not found")
+		}
+		if exists, _ := dn.Exists("page2"); exists {
+			t.Error("Did not expect to find /page2 from crawl, but it was found")
+		}
+	})
+
+	t.Run("use-sitemap", func(t *testing.T) {
+		dn, err := DownloadAndPackageWebsite(server.URL, 2, true, false, server.URL+"/sitemap.xml", nil)
+		if err != nil {
+			t.Fatalf("DownloadAndPackageWebsite with use-sitemap failed: %v", err)
+		}
+		if exists, _ := dn.Exists("page1"); !exists {
+			t.Error("Expected to find /page1 from sitemap, but it was not found")
+		}
+		if exists, _ := dn.Exists("page2"); !exists {
+			t.Error("Expected to find /page2 from crawl, but it was not found")
+		}
+	})
+}
+
 // --- Helpers ---
 
 func newWebsiteTestServer() *httptest.Server {
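
Usage sketch (not part of the patch): a minimal variant of examples/collect_website/main.go exercising the new sitemap parameters. The sitemap URL is illustrative, and the module path is assumed from the github.com/Snider/Borg imports above; on the CLI the same behaviour maps to the new --use-sitemap, --sitemap-only, and --sitemap flags.

package main

import (
	"log"

	"github.com/Snider/Borg/pkg/website"
)

func main() {
	// Sitemap-only collection: useSitemap=false, sitemapOnly=true, with an
	// explicit (illustrative) sitemap URL and no progress bar.
	dn, err := website.DownloadAndPackageWebsite("https://example.com", 2, false, true, "https://example.com/sitemap.xml", nil)
	if err != nil {
		log.Fatalf("Failed to collect website: %v", err)
	}
	_ = dn // package or compress the DataNode as in the existing example
}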