diff --git a/pkg/agentic/provider_manager.go b/pkg/agentic/provider_manager.go index 7bb18bc..9b4ed83 100644 --- a/pkg/agentic/provider_manager.go +++ b/pkg/agentic/provider_manager.go @@ -5,6 +5,7 @@ package agentic import ( "context" "sort" + "time" core "dappco.re/go/core" ) @@ -26,6 +27,11 @@ type ProviderManager struct { providers map[string]AgenticProviderInterface } +var providerRetryBaseDelay = 100 * time.Millisecond +var providerSleep = time.Sleep + +const providerRetryAttempts = 3 + // manager := s.providerManager() // core.Println(manager.Names()) // ["claude", "gemini", "openai"] func (s *PrepSubsystem) providerManager() *ProviderManager { @@ -125,18 +131,44 @@ func (p *contentProvider) Generate(ctx context.Context, prompt string, options m return "", core.E("provider.generate", core.Concat("provider not configured: ", p.name), nil) } - optionsCopy := map[string]any{} - for key, value := range options { - optionsCopy[key] = value - } - if optionsCopy["provider"] == nil { - optionsCopy["provider"] = p.name - } - if optionsCopy["model"] == nil && p.defaultModel != "" { - optionsCopy["model"] = p.defaultModel + var lastErr error + delay := providerRetryBaseDelay + for attempt := 1; attempt <= providerRetryAttempts; attempt++ { + optionsCopy := map[string]any{} + for key, value := range options { + optionsCopy[key] = value + } + if optionsCopy["provider"] == nil { + optionsCopy["provider"] = p.name + } + if optionsCopy["model"] == nil && p.defaultModel != "" { + optionsCopy["model"] = p.defaultModel + } + + content, err := p.generate(ctx, prompt, optionsCopy) + if err == nil { + return content, nil + } + lastErr = err + if attempt == providerRetryAttempts { + break + } + if ctx != nil { + select { + case <-ctx.Done(): + return "", ctx.Err() + default: + } + } + if delay > 0 { + providerSleep(delay) + delay *= 2 + continue + } + delay *= 2 } - return p.generate(ctx, prompt, optionsCopy) + return "", lastErr } func (p *contentProvider) Stream(ctx context.Context, prompt string, options map[string]any, onToken func(string)) error { diff --git a/pkg/agentic/provider_manager_test.go b/pkg/agentic/provider_manager_test.go index 0330534..fc4e4d6 100644 --- a/pkg/agentic/provider_manager_test.go +++ b/pkg/agentic/provider_manager_test.go @@ -5,7 +5,9 @@ package agentic import ( "context" "testing" + "time" + core "dappco.re/go/core" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -43,3 +45,36 @@ func TestProviderManager_ContentProvider_Ugly_NoGeneratorReturnsError(t *testing require.Error(t, err) assert.Contains(t, err.Error(), "provider not configured") } + +func TestProviderManager_ContentProvider_Good_RetriesWithExponentialBackoff(t *testing.T) { + originalSleep := providerSleep + originalDelay := providerRetryBaseDelay + defer func() { + providerSleep = originalSleep + providerRetryBaseDelay = originalDelay + }() + + var delays []time.Duration + providerSleep = func(delay time.Duration) { + delays = append(delays, delay) + } + providerRetryBaseDelay = 50 * time.Millisecond + + attempts := 0 + provider := newContentProvider("claude", "claude-3.7-sonnet", true, func(_ context.Context, _ string, options map[string]any) (string, error) { + attempts++ + if attempts < 3 { + return "", core.E("test.generate", "transient failure", nil) + } + + assert.Equal(t, "claude", options["provider"]) + assert.Equal(t, "claude-3.7-sonnet", options["model"]) + return "Draft ready", nil + }) + + text, err := provider.Generate(context.Background(), "Write a release note", nil) + require.NoError(t, err) + assert.Equal(t, "Draft ready", text) + assert.Equal(t, 3, attempts) + assert.Equal(t, []time.Duration{50 * time.Millisecond, 100 * time.Millisecond}, delays) +}