refactor: Improve test suite and mocking implementation

This commit refactors the test suite and mocking implementation based on user feedback.

- The mock HTTP client in `pkg/mocks/http.go` has been made thread-safe.
- The `pkg/website/website.go` package has been refactored to use dependency injection for the HTTP client.
- The tests in `TDD/collect_commands_test.go` have been updated to use separate buffers for stdout and stderr, and to use a local test server for the `collect website` command.
This commit is contained in:
google-labs-jules[bot] 2025-11-02 20:03:32 +00:00
parent a8dc55aba5
commit 19431ed23d
3 changed files with 53 additions and 46 deletions

View file

@ -3,7 +3,10 @@ package tdd_test
import (
"bytes"
"context"
"fmt"
"github.com/Snider/Borg/cmd"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
@ -11,39 +14,51 @@ import (
func TestCollectCommands(t *testing.T) {
t.Setenv("BORG_PLEXSUS", "0")
// Setup a test server for the website test
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/page2.html" {
fmt.Fprintln(w, `<html><body>Hello</body></html>`)
} else {
fmt.Fprintln(w, `<html><body><a href="/page2.html">Page 2</a></body></html>`)
}
}))
defer server.Close()
testCases := []struct {
name string
args []string
expected string
name string
args []string
expectedStdout string
expectedStderr string
}{
{
name: "collect github repos",
args: []string{"collect", "github", "repos", "test"},
expected: "https://github.com/test/repo1.git",
name: "collect github repos",
args: []string{"collect", "github", "repos", "test"},
expectedStdout: "https://github.com/test/repo1.git",
},
{
name: "collect github repo",
args: []string{"collect", "github", "repo", "https://github.com/test/repo1.git"},
expected: "Repository saved to repo.datanode",
name: "collect github repo",
args: []string{"collect", "github", "repo", "https://github.com/test/repo1.git"},
expectedStdout: "Repository saved to repo.datanode",
},
{
name: "collect pwa",
args: []string{"collect", "pwa", "--uri", "http://test.com"},
expected: "PWA saved to pwa.datanode",
name: "collect pwa",
args: []string{"collect", "pwa", "--uri", "http://test.com"},
expectedStdout: "PWA saved to pwa.datanode",
},
{
name: "collect website",
args: []string{"collect", "website", "http://test.com"},
expected: "Website saved to website.datanode",
name: "collect website",
args: []string{"collect", "website", server.URL},
expectedStdout: "Website saved to website.datanode",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
rootCmd := cmd.NewRootCmd()
b := new(bytes.Buffer)
rootCmd.SetOut(b)
rootCmd.SetErr(b)
outBuf := new(bytes.Buffer)
errBuf := new(bytes.Buffer)
rootCmd.SetOut(outBuf)
rootCmd.SetErr(errBuf)
rootCmd.SetArgs(tc.args)
err := rootCmd.ExecuteContext(context.Background())
@ -51,8 +66,11 @@ func TestCollectCommands(t *testing.T) {
t.Fatal(err)
}
if !strings.Contains(b.String(), tc.expected) {
t.Errorf("expected output to contain %q, but got %q", tc.expected, b.String())
if tc.expectedStdout != "" && !strings.Contains(outBuf.String(), tc.expectedStdout) {
t.Errorf("expected stdout to contain %q, but got %q", tc.expectedStdout, outBuf.String())
}
if tc.expectedStderr != "" && !strings.Contains(errBuf.String(), tc.expectedStderr) {
t.Errorf("expected stderr to contain %q, but got %q", tc.expectedStderr, errBuf.String())
}
})
}

View file

@ -4,17 +4,23 @@ import (
"bytes"
"io"
"net/http"
"sync"
)
// MockRoundTripper is a mock implementation of http.RoundTripper.
type MockRoundTripper struct {
Responses map[string]*http.Response
mu sync.RWMutex
responses map[string]*http.Response
}
// RoundTrip implements the http.RoundTripper interface.
func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
url := req.URL.String()
if resp, ok := m.Responses[url]; ok {
m.mu.RLock()
resp, ok := m.responses[url]
m.mu.RUnlock()
if ok {
// Read the original body
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
@ -55,7 +61,7 @@ func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error)
func NewMockClient(responses map[string]*http.Response) *http.Client {
return &http.Client{
Transport: &MockRoundTripper{
Responses: responses,
responses: responses,
},
}
}

View file

@ -1,13 +1,10 @@
package website
import (
"bytes"
"fmt"
"github.com/Snider/Borg/pkg/mocks"
"io"
"net/http"
"net/url"
"os"
"strings"
"github.com/Snider/Borg/pkg/datanode"
@ -16,25 +13,6 @@ import (
"golang.org/x/net/html"
)
var getHTTPClient = func() *http.Client {
if os.Getenv("BORG_PLEXSUS") == "0" {
responses := map[string]*http.Response{
"http://test.com": {
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBufferString(`<html><body><a href="page2.html">Page 2</a></body></html>`)),
Header: make(http.Header),
},
"http://test.com/page2.html": {
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBufferString(`<html><body>Hello</body></html>`)),
Header: make(http.Header),
},
}
return mocks.NewMockClient(responses)
}
return http.DefaultClient
}
// Downloader is a recursive website downloader.
type Downloader struct {
baseURL *url.URL
@ -47,11 +25,16 @@ type Downloader struct {
// NewDownloader creates a new Downloader.
func NewDownloader(maxDepth int) *Downloader {
return NewDownloaderWithClient(maxDepth, http.DefaultClient)
}
// NewDownloaderWithClient creates a new Downloader with a custom http.Client.
func NewDownloaderWithClient(maxDepth int, client *http.Client) *Downloader {
return &Downloader{
dn: datanode.New(),
visited: make(map[string]bool),
maxDepth: maxDepth,
client: getHTTPClient(),
client: client,
}
}