From cdb4bdbc451297743bc9d3a7373abebd0524ffab Mon Sep 17 00:00:00 2001 From: Snider Date: Sat, 25 Apr 2026 19:19:36 +0100 Subject: [PATCH] fix(mcp/brain/client): block absolute-URL bypass + apiURL scheme allowlist (MEDIUM) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit requestURL() now returns (string, error) and rejects absolute or host-bearing URLs BEFORE request construction and BEFORE Authorization header is set. Closes the bearer-key leak vector: a path that ever flows from upstream JSON, config, or a tool argument can no longer spray the Bearer token at attacker-chosen URLs. New() validates apiURL at construction: - https://* always accepted - http://* rejected unless CORE_BRAIN_INSECURE=true is set (explicit dev/test opt-in; production should always be TLS) Cerberus #1052 from workspace-wide sniff. Today's call sites (Remember, Recall, Forget, List) hardcode the path → safe; this closes the API shape that invited future Call(ctx, method, untrustedPath, body) patterns from leaking the bearer. Tests: absolute http:// + https:// paths make zero HTTP calls, good relative path construction works, http:// apiURL rejected by default + accepted with CORE_BRAIN_INSECURE=true. Existing test fixtures converted to TLS to match the new default policy. Co-authored-by: Codex Closes tasks.lthn.sh/view.php?id=1052 --- pkg/mcp/brain/client/client.go | 36 +++++++-- pkg/mcp/brain/client/client_test.go | 110 +++++++++++++++++++++++++--- 2 files changed, 129 insertions(+), 17 deletions(-) diff --git a/pkg/mcp/brain/client/client.go b/pkg/mcp/brain/client/client.go index 8ea8165..f35fdae 100644 --- a/pkg/mcp/brain/client/client.go +++ b/pkg/mcp/brain/client/client.go @@ -28,6 +28,7 @@ import ( const ( DefaultURL = "https://api.lthn.sh" + insecureBrainEnv = "CORE_BRAIN_INSECURE" brainKeyFileMode = fs.FileMode(0o600) defaultAgentID = "cladius" defaultTimeout = 30 * time.Second @@ -147,6 +148,7 @@ func New(options Options) *Client { if apiURL == "" { apiURL = DefaultURL } + configErr := validateAPIURL(apiURL) agentID := core.Trim(options.AgentID) if agentID == "" { agentID = defaultAgentID @@ -182,6 +184,7 @@ func New(options Options) *Client { baseDelay: baseDelay, maxResponseBytes: maxResponseBytes, circuitBreaker: breaker, + configErr: configErr, sleepFunc: sleepDuration, } } @@ -195,10 +198,26 @@ func NewFromEnvironment() *Client { Org: core.Env("CORE_BRAIN_ORG"), AgentID: core.Env("CORE_BRAIN_AGENT_ID"), }) - client.configErr = configErr + if configErr != nil { + client.configErr = configErr + } return client } +func validateAPIURL(apiURL string) error { + parsed, err := url.Parse(apiURL) + if err != nil || parsed.Scheme == "" || parsed.Host == "" { + return core.E("brain.client", "invalid API URL", err) + } + if parsed.Scheme == "https" { + return nil + } + if parsed.Scheme == "http" && core.Trim(core.Env(insecureBrainEnv)) == "true" { + return nil + } + return core.E("brain.client", "API URL must use https unless CORE_BRAIN_INSECURE=true", nil) +} + // WriteBrainKey stores the OpenBrain API key at ~/.claude/brain.key with owner-only permissions. func WriteBrainKey(apiKey string) error { home := core.Env("HOME") @@ -339,7 +358,11 @@ func (c *Client) doOnce(ctx context.Context, method, path, bodyString string, ha if hasBody { reader = core.NewReader(bodyString) } - request, err := http.NewRequestWithContext(ctx, method, c.requestURL(path), reader) + requestURL, err := c.requestURL(path) + if err != nil { + return nil, false, 0, false, err + } + request, err := http.NewRequestWithContext(ctx, method, requestURL, reader) if err != nil { return nil, false, 0, false, core.E("brain.client", "create request", err) } @@ -385,14 +408,15 @@ func (c *Client) doOnce(ctx context.Context, method, path, bodyString string, ha return result, false, 0, false, nil } -func (c *Client) requestURL(path string) string { - if core.HasPrefix(path, "http://") || core.HasPrefix(path, "https://") { - return path +func (c *Client) requestURL(path string) (string, error) { + parsed, err := url.Parse(path) + if err == nil && (parsed.IsAbs() || parsed.Host != "") { + return "", core.E("brain.client", "absolute request URL rejected", nil) } if !core.HasPrefix(path, "/") { path = core.Concat("/", path) } - return core.Concat(c.apiURL, path) + return core.Concat(c.apiURL, path), nil } func (c *Client) sleep(ctx context.Context, attempt int) error { diff --git a/pkg/mcp/brain/client/client_test.go b/pkg/mcp/brain/client/client_test.go index b83df10..4fdab16 100644 --- a/pkg/mcp/brain/client/client_test.go +++ b/pkg/mcp/brain/client/client_test.go @@ -18,7 +18,7 @@ import ( func TestClientRemember_Good_SendsOrgAndAuth(t *testing.T) { var gotBody map[string]any - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { t.Fatalf("expected POST, got %s", r.Method) } @@ -64,7 +64,7 @@ func TestClientRemember_Good_SendsOrgAndAuth(t *testing.T) { } func TestClientList_Good_SendsOrgURLParam(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { t.Fatalf("expected GET, got %s", r.Method) } @@ -90,9 +90,91 @@ func TestClientList_Good_SendsOrgURLParam(t *testing.T) { } } +func TestClientCall_Good_BuildsRequestAgainstAPIURL(t *testing.T) { + gotHost := "" + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Fatalf("expected POST, got %s", r.Method) + } + if r.URL.Path != "/v1/brain/remember" { + t.Fatalf("expected /v1/brain/remember, got %s", r.URL.Path) + } + gotHost = r.Host + writeJSON(t, w, http.StatusOK, map[string]any{"id": "mem-1"}) + })) + defer server.Close() + + c := New(Options{ + URL: server.URL, + Key: "test-key", + HTTPClient: server.Client(), + MaxAttempts: 1, + }) + + result, err := c.Call(context.Background(), http.MethodPost, "/v1/brain/remember", map[string]any{"content": "safe"}) + if err != nil { + t.Fatalf("Call failed: %v", err) + } + if result["id"] != "mem-1" { + t.Fatalf("expected id mem-1, got %v", result["id"]) + } + if gotHost != strings.TrimPrefix(server.URL, "https://") { + t.Fatalf("expected host %s, got %s", strings.TrimPrefix(server.URL, "https://"), gotHost) + } +} + +func TestClientCall_Bad_RejectsAbsoluteRequestURL(t *testing.T) { + for _, requestPath := range []string{"http://attacker.com/leak", "https://attacker.com/leak"} { + t.Run(requestPath, func(t *testing.T) { + calls := 0 + c := New(Options{ + URL: "https://brain.test", + Key: "test-key", + HTTPClient: &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) { + calls++ + return nil, core.E("test", "unexpected HTTP request", nil) + })}, + MaxAttempts: 1, + }) + + _, err := c.Call(context.Background(), http.MethodPost, requestPath, map[string]any{"content": "leak"}) + if err == nil { + t.Fatal("expected absolute URL error") + } + if !strings.Contains(err.Error(), "absolute request URL rejected") { + t.Fatalf("expected absolute URL rejection, got %v", err) + } + if calls != 0 { + t.Fatalf("expected no HTTP requests, got %d", calls) + } + }) + } +} + +func TestClientNew_Bad_RejectsHTTPAPIURLWithoutInsecureEnv(t *testing.T) { + t.Setenv(insecureBrainEnv, "") + + c := New(Options{URL: "http://internal/", Key: "test-key"}) + if c.configErr == nil { + t.Fatal("expected insecure HTTP API URL to be rejected") + } + if !strings.Contains(c.configErr.Error(), "API URL must use https unless CORE_BRAIN_INSECURE=true") { + t.Fatalf("expected insecure API URL error, got %v", c.configErr) + } +} + +func TestClientNew_Good_AllowsHTTPAPIURLWithInsecureEnv(t *testing.T) { + t.Setenv(insecureBrainEnv, "true") + + c := New(Options{URL: "http://internal/", Key: "test-key"}) + if c.configErr != nil { + t.Fatalf("expected insecure HTTP API URL to be allowed, got %v", c.configErr) + } +} + func TestClientCall_Good_Retries503ThenSucceeds(t *testing.T) { attempts := 0 - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { attempts++ if attempts == 1 { writeJSON(t, w, http.StatusServiceUnavailable, map[string]any{"error": "down"}) @@ -119,7 +201,7 @@ func TestClientCall_Good_Retries503ThenSucceeds(t *testing.T) { func TestClientCall_Good_Retries408ThenSucceeds(t *testing.T) { attempts := 0 - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { attempts++ if attempts == 1 { writeJSON(t, w, http.StatusRequestTimeout, map[string]any{"error": "timeout"}) @@ -146,7 +228,7 @@ func TestClientCall_Good_Retries408ThenSucceeds(t *testing.T) { func TestClientCall_Good_Retries429ThenSucceeds(t *testing.T) { attempts := 0 - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { attempts++ if attempts == 1 { writeJSON(t, w, http.StatusTooManyRequests, map[string]any{"error": "rate limited"}) @@ -173,7 +255,7 @@ func TestClientCall_Good_Retries429ThenSucceeds(t *testing.T) { func TestClientCall_Good_Retries429UsingRetryAfterSeconds(t *testing.T) { attempts := 0 - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { attempts++ if attempts == 1 { w.Header().Set("Retry-After", "1") @@ -213,7 +295,7 @@ func TestClientCall_Good_Retries429UsingRetryAfterSeconds(t *testing.T) { func TestClientCall_Good_Retries429WithPastRetryAfterDateWithoutNegativeSleep(t *testing.T) { attempts := 0 - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { attempts++ if attempts == 1 { w.Header().Set("Retry-After", "Wed, 21 Oct 2015 07:28:00 GMT") @@ -253,7 +335,7 @@ func TestClientCall_Good_Retries429WithPastRetryAfterDateWithoutNegativeSleep(t func TestClientCall_Good_CapsRetryAfterDelay(t *testing.T) { attempts := 0 - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { attempts++ if attempts == 1 { w.Header().Set("Retry-After", "9999") @@ -293,7 +375,7 @@ func TestClientCall_Good_CapsRetryAfterDelay(t *testing.T) { func TestClientCall_Bad_DoesNotRetry400(t *testing.T) { attempts := 0 - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { attempts++ writeJSON(t, w, http.StatusBadRequest, map[string]any{"error": "bad request"}) })) @@ -316,7 +398,7 @@ func TestClientCall_Bad_DoesNotRetry400(t *testing.T) { func TestClientCall_Bad_Continuous503OpensCircuit(t *testing.T) { attempts := 0 - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { attempts++ writeJSON(t, w, http.StatusServiceUnavailable, map[string]any{"error": "down"}) })) @@ -352,7 +434,7 @@ func TestClientCall_Bad_Continuous503OpensCircuit(t *testing.T) { func TestClientCall_Bad_ContextCancellation(t *testing.T) { attempts := 0 - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { attempts++ writeJSON(t, w, http.StatusOK, map[string]any{"ok": true}) })) @@ -446,3 +528,9 @@ func writeJSON(t *testing.T, w http.ResponseWriter, status int, payload any) { t.Fatalf("failed to write response: %v", err) } } + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (fn roundTripFunc) RoundTrip(request *http.Request) (*http.Response, error) { + return fn(request) +}