refactor: Improve test suite and mocking implementation
This commit refactors the test suite and mocking implementation based on user feedback. - The mock HTTP client in `pkg/mocks/http.go` has been made thread-safe. - The `pkg/website/website.go` package has been refactored to use dependency injection for the HTTP client. - The tests in `TDD/collect_commands_test.go` have been updated to use separate buffers for stdout and stderr, and to use a local test server for the `collect website` command.
This commit is contained in:
parent
a8dc55aba5
commit
19431ed23d
3 changed files with 53 additions and 46 deletions
|
|
@ -3,7 +3,10 @@ package tdd_test
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/Snider/Borg/cmd"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
|
@ -11,39 +14,51 @@ import (
|
|||
func TestCollectCommands(t *testing.T) {
|
||||
t.Setenv("BORG_PLEXSUS", "0")
|
||||
|
||||
// Setup a test server for the website test
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/page2.html" {
|
||||
fmt.Fprintln(w, `<html><body>Hello</body></html>`)
|
||||
} else {
|
||||
fmt.Fprintln(w, `<html><body><a href="/page2.html">Page 2</a></body></html>`)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
args []string
|
||||
expected string
|
||||
name string
|
||||
args []string
|
||||
expectedStdout string
|
||||
expectedStderr string
|
||||
}{
|
||||
{
|
||||
name: "collect github repos",
|
||||
args: []string{"collect", "github", "repos", "test"},
|
||||
expected: "https://github.com/test/repo1.git",
|
||||
name: "collect github repos",
|
||||
args: []string{"collect", "github", "repos", "test"},
|
||||
expectedStdout: "https://github.com/test/repo1.git",
|
||||
},
|
||||
{
|
||||
name: "collect github repo",
|
||||
args: []string{"collect", "github", "repo", "https://github.com/test/repo1.git"},
|
||||
expected: "Repository saved to repo.datanode",
|
||||
name: "collect github repo",
|
||||
args: []string{"collect", "github", "repo", "https://github.com/test/repo1.git"},
|
||||
expectedStdout: "Repository saved to repo.datanode",
|
||||
},
|
||||
{
|
||||
name: "collect pwa",
|
||||
args: []string{"collect", "pwa", "--uri", "http://test.com"},
|
||||
expected: "PWA saved to pwa.datanode",
|
||||
name: "collect pwa",
|
||||
args: []string{"collect", "pwa", "--uri", "http://test.com"},
|
||||
expectedStdout: "PWA saved to pwa.datanode",
|
||||
},
|
||||
{
|
||||
name: "collect website",
|
||||
args: []string{"collect", "website", "http://test.com"},
|
||||
expected: "Website saved to website.datanode",
|
||||
name: "collect website",
|
||||
args: []string{"collect", "website", server.URL},
|
||||
expectedStdout: "Website saved to website.datanode",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
rootCmd := cmd.NewRootCmd()
|
||||
b := new(bytes.Buffer)
|
||||
rootCmd.SetOut(b)
|
||||
rootCmd.SetErr(b)
|
||||
outBuf := new(bytes.Buffer)
|
||||
errBuf := new(bytes.Buffer)
|
||||
rootCmd.SetOut(outBuf)
|
||||
rootCmd.SetErr(errBuf)
|
||||
rootCmd.SetArgs(tc.args)
|
||||
|
||||
err := rootCmd.ExecuteContext(context.Background())
|
||||
|
|
@ -51,8 +66,11 @@ func TestCollectCommands(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !strings.Contains(b.String(), tc.expected) {
|
||||
t.Errorf("expected output to contain %q, but got %q", tc.expected, b.String())
|
||||
if tc.expectedStdout != "" && !strings.Contains(outBuf.String(), tc.expectedStdout) {
|
||||
t.Errorf("expected stdout to contain %q, but got %q", tc.expectedStdout, outBuf.String())
|
||||
}
|
||||
if tc.expectedStderr != "" && !strings.Contains(errBuf.String(), tc.expectedStderr) {
|
||||
t.Errorf("expected stderr to contain %q, but got %q", tc.expectedStderr, errBuf.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,17 +4,23 @@ import (
|
|||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// MockRoundTripper is a mock implementation of http.RoundTripper.
|
||||
type MockRoundTripper struct {
|
||||
Responses map[string]*http.Response
|
||||
mu sync.RWMutex
|
||||
responses map[string]*http.Response
|
||||
}
|
||||
|
||||
// RoundTrip implements the http.RoundTripper interface.
|
||||
func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
url := req.URL.String()
|
||||
if resp, ok := m.Responses[url]; ok {
|
||||
m.mu.RLock()
|
||||
resp, ok := m.responses[url]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if ok {
|
||||
// Read the original body
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
|
|
@ -55,7 +61,7 @@ func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error)
|
|||
func NewMockClient(responses map[string]*http.Response) *http.Client {
|
||||
return &http.Client{
|
||||
Transport: &MockRoundTripper{
|
||||
Responses: responses,
|
||||
responses: responses,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,13 +1,10 @@
|
|||
package website
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"github.com/Snider/Borg/pkg/mocks"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/Snider/Borg/pkg/datanode"
|
||||
|
|
@ -16,25 +13,6 @@ import (
|
|||
"golang.org/x/net/html"
|
||||
)
|
||||
|
||||
var getHTTPClient = func() *http.Client {
|
||||
if os.Getenv("BORG_PLEXSUS") == "0" {
|
||||
responses := map[string]*http.Response{
|
||||
"http://test.com": {
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(bytes.NewBufferString(`<html><body><a href="page2.html">Page 2</a></body></html>`)),
|
||||
Header: make(http.Header),
|
||||
},
|
||||
"http://test.com/page2.html": {
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(bytes.NewBufferString(`<html><body>Hello</body></html>`)),
|
||||
Header: make(http.Header),
|
||||
},
|
||||
}
|
||||
return mocks.NewMockClient(responses)
|
||||
}
|
||||
return http.DefaultClient
|
||||
}
|
||||
|
||||
// Downloader is a recursive website downloader.
|
||||
type Downloader struct {
|
||||
baseURL *url.URL
|
||||
|
|
@ -47,11 +25,16 @@ type Downloader struct {
|
|||
|
||||
// NewDownloader creates a new Downloader.
|
||||
func NewDownloader(maxDepth int) *Downloader {
|
||||
return NewDownloaderWithClient(maxDepth, http.DefaultClient)
|
||||
}
|
||||
|
||||
// NewDownloaderWithClient creates a new Downloader with a custom http.Client.
|
||||
func NewDownloaderWithClient(maxDepth int, client *http.Client) *Downloader {
|
||||
return &Downloader{
|
||||
dn: datanode.New(),
|
||||
visited: make(map[string]bool),
|
||||
maxDepth: maxDepth,
|
||||
client: getHTTPClient(),
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue