From 19431ed23d28a5594402f4722b038f2cfcfc5ba8 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sun, 2 Nov 2025 20:03:32 +0000 Subject: [PATCH] refactor: Improve test suite and mocking implementation This commit refactors the test suite and mocking implementation based on user feedback. - The mock HTTP client in `pkg/mocks/http.go` has been made thread-safe. - The `pkg/website/website.go` package has been refactored to use dependency injection for the HTTP client. - The tests in `TDD/collect_commands_test.go` have been updated to use separate buffers for stdout and stderr, and to use a local test server for the `collect website` command. --- TDD/collect_commands_test.go | 58 +++++++++++++++++++++++------------- pkg/mocks/http.go | 12 ++++++-- pkg/website/website.go | 29 ++++-------------- 3 files changed, 53 insertions(+), 46 deletions(-) diff --git a/TDD/collect_commands_test.go b/TDD/collect_commands_test.go index 156408f..84740b1 100644 --- a/TDD/collect_commands_test.go +++ b/TDD/collect_commands_test.go @@ -3,7 +3,10 @@ package tdd_test import ( "bytes" "context" + "fmt" "github.com/Snider/Borg/cmd" + "net/http" + "net/http/httptest" "strings" "testing" ) @@ -11,39 +14,51 @@ import ( func TestCollectCommands(t *testing.T) { t.Setenv("BORG_PLEXSUS", "0") + // Setup a test server for the website test + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/page2.html" { + fmt.Fprintln(w, `
Hello`) + } else { + fmt.Fprintln(w, `Page 2`) + } + })) + defer server.Close() + testCases := []struct { - name string - args []string - expected string + name string + args []string + expectedStdout string + expectedStderr string }{ { - name: "collect github repos", - args: []string{"collect", "github", "repos", "test"}, - expected: "https://github.com/test/repo1.git", + name: "collect github repos", + args: []string{"collect", "github", "repos", "test"}, + expectedStdout: "https://github.com/test/repo1.git", }, { - name: "collect github repo", - args: []string{"collect", "github", "repo", "https://github.com/test/repo1.git"}, - expected: "Repository saved to repo.datanode", + name: "collect github repo", + args: []string{"collect", "github", "repo", "https://github.com/test/repo1.git"}, + expectedStdout: "Repository saved to repo.datanode", }, { - name: "collect pwa", - args: []string{"collect", "pwa", "--uri", "http://test.com"}, - expected: "PWA saved to pwa.datanode", + name: "collect pwa", + args: []string{"collect", "pwa", "--uri", "http://test.com"}, + expectedStdout: "PWA saved to pwa.datanode", }, { - name: "collect website", - args: []string{"collect", "website", "http://test.com"}, - expected: "Website saved to website.datanode", + name: "collect website", + args: []string{"collect", "website", server.URL}, + expectedStdout: "Website saved to website.datanode", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { rootCmd := cmd.NewRootCmd() - b := new(bytes.Buffer) - rootCmd.SetOut(b) - rootCmd.SetErr(b) + outBuf := new(bytes.Buffer) + errBuf := new(bytes.Buffer) + rootCmd.SetOut(outBuf) + rootCmd.SetErr(errBuf) rootCmd.SetArgs(tc.args) err := rootCmd.ExecuteContext(context.Background()) @@ -51,8 +66,11 @@ func TestCollectCommands(t *testing.T) { t.Fatal(err) } - if !strings.Contains(b.String(), tc.expected) { - t.Errorf("expected output to contain %q, but got %q", tc.expected, b.String()) + if tc.expectedStdout != "" && !strings.Contains(outBuf.String(), tc.expectedStdout) { + t.Errorf("expected stdout to contain %q, but got %q", tc.expectedStdout, outBuf.String()) + } + if tc.expectedStderr != "" && !strings.Contains(errBuf.String(), tc.expectedStderr) { + t.Errorf("expected stderr to contain %q, but got %q", tc.expectedStderr, errBuf.String()) } }) } diff --git a/pkg/mocks/http.go b/pkg/mocks/http.go index 29ba838..e8cfabc 100644 --- a/pkg/mocks/http.go +++ b/pkg/mocks/http.go @@ -4,17 +4,23 @@ import ( "bytes" "io" "net/http" + "sync" ) // MockRoundTripper is a mock implementation of http.RoundTripper. type MockRoundTripper struct { - Responses map[string]*http.Response + mu sync.RWMutex + responses map[string]*http.Response } // RoundTrip implements the http.RoundTripper interface. func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { url := req.URL.String() - if resp, ok := m.Responses[url]; ok { + m.mu.RLock() + resp, ok := m.responses[url] + m.mu.RUnlock() + + if ok { // Read the original body bodyBytes, err := io.ReadAll(resp.Body) if err != nil { @@ -55,7 +61,7 @@ func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) func NewMockClient(responses map[string]*http.Response) *http.Client { return &http.Client{ Transport: &MockRoundTripper{ - Responses: responses, + responses: responses, }, } } diff --git a/pkg/website/website.go b/pkg/website/website.go index e98a6b0..1780c0e 100644 --- a/pkg/website/website.go +++ b/pkg/website/website.go @@ -1,13 +1,10 @@ package website import ( - "bytes" "fmt" - "github.com/Snider/Borg/pkg/mocks" "io" "net/http" "net/url" - "os" "strings" "github.com/Snider/Borg/pkg/datanode" @@ -16,25 +13,6 @@ import ( "golang.org/x/net/html" ) -var getHTTPClient = func() *http.Client { - if os.Getenv("BORG_PLEXSUS") == "0" { - responses := map[string]*http.Response{ - "http://test.com": { - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewBufferString(`Page 2`)), - Header: make(http.Header), - }, - "http://test.com/page2.html": { - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewBufferString(`Hello`)), - Header: make(http.Header), - }, - } - return mocks.NewMockClient(responses) - } - return http.DefaultClient -} - // Downloader is a recursive website downloader. type Downloader struct { baseURL *url.URL @@ -47,11 +25,16 @@ type Downloader struct { // NewDownloader creates a new Downloader. func NewDownloader(maxDepth int) *Downloader { + return NewDownloaderWithClient(maxDepth, http.DefaultClient) +} + +// NewDownloaderWithClient creates a new Downloader with a custom http.Client. +func NewDownloaderWithClient(maxDepth int, client *http.Client) *Downloader { return &Downloader{ dn: datanode.New(), visited: make(map[string]bool), maxDepth: maxDepth, - client: getHTTPClient(), + client: client, } }