From 2acf186925b922d4ffa908c02a71df8e75e44c7b Mon Sep 17 00:00:00 2001 From: Snider Date: Sat, 25 Apr 2026 18:13:07 +0100 Subject: [PATCH] fix(mcp/brain/client): retryableStatus 408+429 + honour Retry-After retryableStatus now treats 408 (Request Timeout) and 429 (Too Many Requests) as retryable, matching the PHP-side equivalent. Retry loop parses Retry-After from the response (seconds form OR HTTP-date), clamps to 60s max, treats past dates as zero delay. Tests cover: 408 retried, 429 retried, numeric Retry-After honoured, past-date Retry-After produces zero sleep, long Retry-After capped. Co-authored-by: Codex Closes tasks.lthn.sh/view.php?id=996 --- pkg/mcp/brain/client/client.go | 84 +++++++++++--- pkg/mcp/brain/client/client_test.go | 174 ++++++++++++++++++++++++++++ 2 files changed, 244 insertions(+), 14 deletions(-) diff --git a/pkg/mcp/brain/client/client.go b/pkg/mcp/brain/client/client.go index aff3e60..8ea8165 100644 --- a/pkg/mcp/brain/client/client.go +++ b/pkg/mcp/brain/client/client.go @@ -18,6 +18,7 @@ import ( "net/http" "net/url" "os" + "strconv" "sync" "time" @@ -36,6 +37,7 @@ const ( defaultSuccessThreshold = 1 defaultCircuitCooldown = 30 * time.Second defaultMaxResponseBytes = int64(1 << 20) + maxRetryAfterDelay = 60 * time.Second defaultRecallTopK = 10 defaultListLimit = 50 ) @@ -68,6 +70,7 @@ type Client struct { maxResponseBytes int64 circuitBreaker *CircuitBreaker configErr error + sleepFunc func(context.Context, time.Duration) error } // RememberInput is the request body for POST /v1/brain/remember. @@ -179,6 +182,7 @@ func New(options Options) *Client { baseDelay: baseDelay, maxResponseBytes: maxResponseBytes, circuitBreaker: breaker, + sleepFunc: sleepDuration, } } @@ -300,7 +304,7 @@ func (c *Client) Call(ctx context.Context, method, path string, body any) (map[s var lastErr error for attempt := 1; attempt <= c.maxAttempts; attempt++ { - payload, retryable, err := c.doOnce(ctx, method, path, bodyString, body != nil) + payload, retryable, retryAfter, hasRetryAfter, err := c.doOnce(ctx, method, path, bodyString, body != nil) if err == nil { c.circuitBreaker.recordSuccess() return payload, nil @@ -315,7 +319,13 @@ func (c *Client) Call(ctx context.Context, method, path string, body any) (map[s if c.circuitBreaker.State() == CircuitOpen || attempt == c.maxAttempts { break } - if sleepErr := c.sleep(ctx, attempt); sleepErr != nil { + var sleepErr error + if hasRetryAfter { + sleepErr = c.sleepFor(ctx, retryAfter) + } else { + sleepErr = c.sleep(ctx, attempt) + } + if sleepErr != nil { lastErr = sleepErr break } @@ -324,14 +334,14 @@ func (c *Client) Call(ctx context.Context, method, path string, body any) (map[s return nil, lastErr } -func (c *Client) doOnce(ctx context.Context, method, path, bodyString string, hasBody bool) (map[string]any, bool, error) { +func (c *Client) doOnce(ctx context.Context, method, path, bodyString string, hasBody bool) (map[string]any, bool, time.Duration, bool, error) { var reader io.Reader if hasBody { reader = core.NewReader(bodyString) } request, err := http.NewRequestWithContext(ctx, method, c.requestURL(path), reader) if err != nil { - return nil, false, core.E("brain.client", "create request", err) + return nil, false, 0, false, core.E("brain.client", "create request", err) } request.Header.Set("Accept", "application/json") request.Header.Set("Authorization", core.Concat("Bearer ", c.apiKey)) @@ -342,36 +352,37 @@ func (c *Client) doOnce(ctx context.Context, method, path, bodyString string, ha response, err := c.httpClient.Do(request) if err != nil { if ctx.Err() != nil { - return nil, false, core.E("brain.client", "request cancelled", ctx.Err()) + return nil, false, 0, false, core.E("brain.client", "request cancelled", ctx.Err()) } - return nil, true, core.E("brain.client", "request failed", err) + return nil, true, 0, false, core.E("brain.client", "request failed", err) } defer response.Body.Close() readResult := core.ReadAll(io.LimitReader(response.Body, c.maxResponseBytes+1)) if !readResult.OK { if readErr, ok := readResult.Value.(error); ok { - return nil, false, core.E("brain.client", "read response", readErr) + return nil, false, 0, false, core.E("brain.client", "read response", readErr) } - return nil, false, core.E("brain.client", "read response", nil) + return nil, false, 0, false, core.E("brain.client", "read response", nil) } raw := readResult.Value.(string) if int64(len(raw)) > c.maxResponseBytes { - return nil, false, core.E("brain.client", "response too large", nil) + return nil, false, 0, false, core.E("brain.client", "response too large", nil) } if response.StatusCode >= http.StatusBadRequest { - return nil, retryableStatus(response.StatusCode), core.E("brain.client", core.Concat("upstream returned ", response.Status, ": ", core.Trim(raw)), nil) + retryAfter, hasRetryAfter := parseRetryAfter(response.Header.Get("Retry-After"), time.Now()) + return nil, retryableStatus(response.StatusCode), retryAfter, hasRetryAfter, core.E("brain.client", core.Concat("upstream returned ", response.Status, ": ", core.Trim(raw)), nil) } result := map[string]any{} if parseResult := core.JSONUnmarshalString(raw, &result); !parseResult.OK { if parseErr, ok := parseResult.Value.(error); ok { - return nil, false, core.E("brain.client", "parse response", parseErr) + return nil, false, 0, false, core.E("brain.client", "parse response", parseErr) } - return nil, false, core.E("brain.client", "parse response", nil) + return nil, false, 0, false, core.E("brain.client", "parse response", nil) } - return result, false, nil + return result, false, 0, false, nil } func (c *Client) requestURL(path string) string { @@ -389,6 +400,20 @@ func (c *Client) sleep(ctx context.Context, attempt int) error { for i := 1; i < attempt; i++ { delay *= 3 } + return c.sleepFor(ctx, delay) +} + +func (c *Client) sleepFor(ctx context.Context, delay time.Duration) error { + if c.sleepFunc != nil { + return c.sleepFunc(ctx, delay) + } + return sleepDuration(ctx, delay) +} + +func sleepDuration(ctx context.Context, delay time.Duration) error { + if delay <= 0 { + return nil + } timer := time.NewTimer(delay) defer timer.Stop() select { @@ -493,7 +518,38 @@ func (breaker *CircuitBreaker) stateNow(now time.Time) CircuitState { } func retryableStatus(statusCode int) bool { - return statusCode >= http.StatusInternalServerError + return statusCode == http.StatusRequestTimeout || statusCode == http.StatusTooManyRequests || statusCode >= http.StatusInternalServerError +} + +func parseRetryAfter(value string, now time.Time) (time.Duration, bool) { + value = core.Trim(value) + if value == "" { + return 0, false + } + + if seconds, err := strconv.ParseInt(value, 10, 64); err == nil { + if seconds <= 0 { + return 0, true + } + maxSeconds := int64(maxRetryAfterDelay / time.Second) + if seconds > maxSeconds { + return maxRetryAfterDelay, true + } + return time.Duration(seconds) * time.Second, true + } + + retryAt, err := http.ParseTime(value) + if err != nil { + return 0, false + } + delay := retryAt.Sub(now) + if delay <= 0 { + return 0, true + } + if delay > maxRetryAfterDelay { + return maxRetryAfterDelay, true + } + return delay, true } func envOr(key, fallback string) string { diff --git a/pkg/mcp/brain/client/client_test.go b/pkg/mcp/brain/client/client_test.go index 9a11da3..b83df10 100644 --- a/pkg/mcp/brain/client/client_test.go +++ b/pkg/mcp/brain/client/client_test.go @@ -117,6 +117,180 @@ func TestClientCall_Good_Retries503ThenSucceeds(t *testing.T) { } } +func TestClientCall_Good_Retries408ThenSucceeds(t *testing.T) { + attempts := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + writeJSON(t, w, http.StatusRequestTimeout, map[string]any{"error": "timeout"}) + return + } + writeJSON(t, w, http.StatusOK, map[string]any{"memories": []any{}}) + })) + defer server.Close() + + c := New(Options{ + URL: server.URL, + Key: "test-key", + HTTPClient: server.Client(), + MaxAttempts: 3, + BaseDelay: time.Nanosecond, + }) + if _, err := c.Recall(context.Background(), RecallInput{Query: "retry"}); err != nil { + t.Fatalf("Recall failed after retry: %v", err) + } + if attempts != 2 { + t.Fatalf("expected 2 attempts, got %d", attempts) + } +} + +func TestClientCall_Good_Retries429ThenSucceeds(t *testing.T) { + attempts := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + writeJSON(t, w, http.StatusTooManyRequests, map[string]any{"error": "rate limited"}) + return + } + writeJSON(t, w, http.StatusOK, map[string]any{"memories": []any{}}) + })) + defer server.Close() + + c := New(Options{ + URL: server.URL, + Key: "test-key", + HTTPClient: server.Client(), + MaxAttempts: 3, + BaseDelay: time.Nanosecond, + }) + if _, err := c.Recall(context.Background(), RecallInput{Query: "retry"}); err != nil { + t.Fatalf("Recall failed after retry: %v", err) + } + if attempts != 2 { + t.Fatalf("expected 2 attempts, got %d", attempts) + } +} + +func TestClientCall_Good_Retries429UsingRetryAfterSeconds(t *testing.T) { + attempts := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.Header().Set("Retry-After", "1") + writeJSON(t, w, http.StatusTooManyRequests, map[string]any{"error": "rate limited"}) + return + } + writeJSON(t, w, http.StatusOK, map[string]any{"memories": []any{}}) + })) + defer server.Close() + + c := New(Options{ + URL: server.URL, + Key: "test-key", + HTTPClient: server.Client(), + MaxAttempts: 3, + BaseDelay: time.Nanosecond, + }) + sleeps := []time.Duration{} + c.sleepFunc = func(ctx context.Context, delay time.Duration) error { + sleeps = append(sleeps, delay) + return nil + } + + if _, err := c.Recall(context.Background(), RecallInput{Query: "retry"}); err != nil { + t.Fatalf("Recall failed after retry: %v", err) + } + if attempts != 2 { + t.Fatalf("expected 2 attempts, got %d", attempts) + } + if len(sleeps) != 1 { + t.Fatalf("expected one retry sleep, got %d", len(sleeps)) + } + if sleeps[0] != time.Second { + t.Fatalf("expected Retry-After sleep of 1s, got %v", sleeps[0]) + } +} + +func TestClientCall_Good_Retries429WithPastRetryAfterDateWithoutNegativeSleep(t *testing.T) { + attempts := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.Header().Set("Retry-After", "Wed, 21 Oct 2015 07:28:00 GMT") + writeJSON(t, w, http.StatusTooManyRequests, map[string]any{"error": "rate limited"}) + return + } + writeJSON(t, w, http.StatusOK, map[string]any{"memories": []any{}}) + })) + defer server.Close() + + c := New(Options{ + URL: server.URL, + Key: "test-key", + HTTPClient: server.Client(), + MaxAttempts: 3, + BaseDelay: time.Nanosecond, + }) + sleeps := []time.Duration{} + c.sleepFunc = func(ctx context.Context, delay time.Duration) error { + sleeps = append(sleeps, delay) + return nil + } + + if _, err := c.Recall(context.Background(), RecallInput{Query: "retry"}); err != nil { + t.Fatalf("Recall failed after retry: %v", err) + } + if attempts != 2 { + t.Fatalf("expected 2 attempts, got %d", attempts) + } + if len(sleeps) != 1 { + t.Fatalf("expected one retry sleep, got %d", len(sleeps)) + } + if sleeps[0] != 0 { + t.Fatalf("expected past Retry-After date to sleep zero, got %v", sleeps[0]) + } +} + +func TestClientCall_Good_CapsRetryAfterDelay(t *testing.T) { + attempts := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.Header().Set("Retry-After", "9999") + writeJSON(t, w, http.StatusServiceUnavailable, map[string]any{"error": "down"}) + return + } + writeJSON(t, w, http.StatusOK, map[string]any{"memories": []any{}}) + })) + defer server.Close() + + c := New(Options{ + URL: server.URL, + Key: "test-key", + HTTPClient: server.Client(), + MaxAttempts: 3, + BaseDelay: time.Nanosecond, + }) + sleeps := []time.Duration{} + c.sleepFunc = func(ctx context.Context, delay time.Duration) error { + sleeps = append(sleeps, delay) + return nil + } + + if _, err := c.Recall(context.Background(), RecallInput{Query: "retry"}); err != nil { + t.Fatalf("Recall failed after retry: %v", err) + } + if attempts != 2 { + t.Fatalf("expected 2 attempts, got %d", attempts) + } + if len(sleeps) != 1 { + t.Fatalf("expected one retry sleep, got %d", len(sleeps)) + } + if sleeps[0] != maxRetryAfterDelay { + t.Fatalf("expected capped Retry-After sleep of %v, got %v", maxRetryAfterDelay, sleeps[0]) + } +} + func TestClientCall_Bad_DoesNotRetry400(t *testing.T) { attempts := 0 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {