feat(agentic): add content provider registry
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
75fc9d4bf4
commit
b693695e41
4 changed files with 329 additions and 29 deletions
|
|
@ -204,6 +204,59 @@ func (s *PrepSubsystem) handleContentGenerate(ctx context.Context, options core.
|
|||
return core.Result{Value: output, OK: true}
|
||||
}
|
||||
|
||||
func (s *PrepSubsystem) contentGenerateResult(ctx context.Context, input ContentGenerateInput) (ContentResult, error) {
|
||||
if err := s.validateContentProvider(input.Provider); err != nil {
|
||||
return ContentResult{}, err
|
||||
}
|
||||
|
||||
hasPrompt := core.Trim(input.Prompt) != ""
|
||||
hasBrief := core.Trim(input.BriefID) != ""
|
||||
hasTemplate := core.Trim(input.Template) != ""
|
||||
if !hasPrompt && !(hasBrief && hasTemplate) {
|
||||
return ContentResult{}, core.E("contentGenerate", "prompt or brief_id plus template is required", nil)
|
||||
}
|
||||
|
||||
body := map[string]any{}
|
||||
if hasPrompt {
|
||||
body["prompt"] = input.Prompt
|
||||
}
|
||||
if input.BriefID != "" {
|
||||
body["brief_id"] = input.BriefID
|
||||
}
|
||||
if input.Template != "" {
|
||||
body["template"] = input.Template
|
||||
}
|
||||
if input.Provider != "" {
|
||||
body["provider"] = input.Provider
|
||||
}
|
||||
if len(input.Config) > 0 {
|
||||
body["config"] = input.Config
|
||||
}
|
||||
|
||||
result := s.platformPayload(ctx, "content.generate", "POST", "/v1/content/generate", body)
|
||||
if !result.OK {
|
||||
return ContentResult{}, resultErrorValue("content.generate", result)
|
||||
}
|
||||
|
||||
return parseContentResult(payloadResourceMap(result.Value.(map[string]any), "result", "content", "generation")), nil
|
||||
}
|
||||
|
||||
func (s *PrepSubsystem) validateContentProvider(providerName string) error {
|
||||
if core.Trim(providerName) == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
provider, ok := s.providerManager().Provider(providerName)
|
||||
if !ok {
|
||||
return core.E("contentGenerate", core.Concat("unknown provider: ", providerName), nil)
|
||||
}
|
||||
if !provider.IsAvailable() {
|
||||
return core.E("contentGenerate", core.Concat("provider unavailable: ", providerName), nil)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// result := c.Action("content.batch.generate").Run(ctx, core.NewOptions(core.Option{Key: "batch_id", Value: "batch_123"}))
|
||||
func (s *PrepSubsystem) handleContentBatchGenerate(ctx context.Context, options core.Options) core.Result {
|
||||
_, output, err := s.contentBatchGenerate(ctx, nil, ContentBatchGenerateInput{
|
||||
|
|
@ -387,38 +440,13 @@ func (s *PrepSubsystem) registerContentTools(server *mcp.Server) {
|
|||
}
|
||||
|
||||
func (s *PrepSubsystem) contentGenerate(ctx context.Context, _ *mcp.CallToolRequest, input ContentGenerateInput) (*mcp.CallToolResult, ContentGenerateOutput, error) {
|
||||
hasPrompt := core.Trim(input.Prompt) != ""
|
||||
hasBrief := core.Trim(input.BriefID) != ""
|
||||
hasTemplate := core.Trim(input.Template) != ""
|
||||
if !hasPrompt && !(hasBrief && hasTemplate) {
|
||||
return nil, ContentGenerateOutput{}, core.E("contentGenerate", "prompt or brief_id plus template is required", nil)
|
||||
content, err := s.contentGenerateResult(ctx, input)
|
||||
if err != nil {
|
||||
return nil, ContentGenerateOutput{}, err
|
||||
}
|
||||
|
||||
body := map[string]any{}
|
||||
if hasPrompt {
|
||||
body["prompt"] = input.Prompt
|
||||
}
|
||||
if input.BriefID != "" {
|
||||
body["brief_id"] = input.BriefID
|
||||
}
|
||||
if input.Template != "" {
|
||||
body["template"] = input.Template
|
||||
}
|
||||
if input.Provider != "" {
|
||||
body["provider"] = input.Provider
|
||||
}
|
||||
if len(input.Config) > 0 {
|
||||
body["config"] = input.Config
|
||||
}
|
||||
|
||||
result := s.platformPayload(ctx, "content.generate", "POST", "/v1/content/generate", body)
|
||||
if !result.OK {
|
||||
return nil, ContentGenerateOutput{}, resultErrorValue("content.generate", result)
|
||||
}
|
||||
|
||||
return nil, ContentGenerateOutput{
|
||||
Success: true,
|
||||
Result: parseContentResult(payloadResourceMap(result.Value.(map[string]any), "result", "content", "generation")),
|
||||
Result: content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -426,6 +454,9 @@ func (s *PrepSubsystem) contentBatchGenerate(ctx context.Context, _ *mcp.CallToo
|
|||
if core.Trim(input.BatchID) == "" {
|
||||
return nil, ContentBatchOutput{}, core.E("contentBatchGenerate", "batch_id is required", nil)
|
||||
}
|
||||
if err := s.validateContentProvider(input.Provider); err != nil {
|
||||
return nil, ContentBatchOutput{}, err
|
||||
}
|
||||
|
||||
body := map[string]any{
|
||||
"batch_id": input.BatchID,
|
||||
|
|
@ -563,6 +594,9 @@ func (s *PrepSubsystem) contentFromPlan(ctx context.Context, _ *mcp.CallToolRequ
|
|||
if core.Trim(input.PlanSlug) == "" {
|
||||
return nil, ContentFromPlanOutput{}, core.E("contentFromPlan", "plan_slug is required", nil)
|
||||
}
|
||||
if err := s.validateContentProvider(input.Provider); err != nil {
|
||||
return nil, ContentFromPlanOutput{}, err
|
||||
}
|
||||
|
||||
body := map[string]any{
|
||||
"plan_slug": input.PlanSlug,
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ type PrepSubsystem struct {
|
|||
frozen bool
|
||||
backoff map[string]time.Time
|
||||
failCount map[string]int
|
||||
providers *ProviderManager
|
||||
workspaces *core.Registry[*WorkspaceStatus]
|
||||
}
|
||||
|
||||
|
|
|
|||
220
pkg/agentic/provider_manager.go
Normal file
220
pkg/agentic/provider_manager.go
Normal file
|
|
@ -0,0 +1,220 @@
|
|||
// SPDX-License-Identifier: EUPL-1.2
|
||||
|
||||
package agentic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
|
||||
core "dappco.re/go/core"
|
||||
)
|
||||
|
||||
// provider := agentic.NewProviderManager(nil).Provider("claude")
|
||||
//
|
||||
// core.Println(provider.Name()) // "claude"
|
||||
type AgenticProviderInterface interface {
|
||||
Generate(context.Context, string, map[string]any) (string, error)
|
||||
Stream(context.Context, string, map[string]any, func(string)) error
|
||||
Name() string
|
||||
DefaultModel() string
|
||||
IsAvailable() bool
|
||||
}
|
||||
|
||||
// manager := agentic.NewProviderManager(nil)
|
||||
// core.Println(manager.Names()) // ["claude", "gemini", "openai"]
|
||||
type ProviderManager struct {
|
||||
providers map[string]AgenticProviderInterface
|
||||
}
|
||||
|
||||
// providerManager returns the lazily initialised provider registry for content generation.
|
||||
//
|
||||
// manager := s.providerManager()
|
||||
// core.Println(manager.Names())
|
||||
func (s *PrepSubsystem) providerManager() *ProviderManager {
|
||||
if s == nil {
|
||||
return NewProviderManager(nil)
|
||||
}
|
||||
if s.providers != nil {
|
||||
return s.providers
|
||||
}
|
||||
|
||||
s.providers = NewProviderManager(func(ctx context.Context, prompt string, options map[string]any) (string, error) {
|
||||
config := anyMapValue(options["config"])
|
||||
if model := contentMapStringValue(options, "model"); model != "" {
|
||||
if config == nil {
|
||||
config = map[string]any{}
|
||||
}
|
||||
config["model"] = model
|
||||
}
|
||||
input := ContentGenerateInput{
|
||||
Prompt: prompt,
|
||||
Provider: contentMapStringValue(options, "provider"),
|
||||
Config: config,
|
||||
}
|
||||
if template := contentMapStringValue(options, "template"); template != "" {
|
||||
input.Template = template
|
||||
}
|
||||
if briefID := contentMapStringValue(options, "brief_id", "briefId"); briefID != "" {
|
||||
input.BriefID = briefID
|
||||
}
|
||||
result, err := s.contentGenerateResult(ctx, input)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return result.Content, nil
|
||||
})
|
||||
|
||||
return s.providers
|
||||
}
|
||||
|
||||
// NewProviderManager registers the built-in content providers.
|
||||
//
|
||||
// manager := agentic.NewProviderManager(func(ctx context.Context, prompt string, options map[string]any) (string, error) {
|
||||
// return "Draft ready", nil
|
||||
// })
|
||||
func NewProviderManager(generate ProviderGenerateFunc) *ProviderManager {
|
||||
manager := &ProviderManager{
|
||||
providers: make(map[string]AgenticProviderInterface),
|
||||
}
|
||||
|
||||
manager.Register(newContentProvider("claude", "claude-3.7-sonnet", true, generate))
|
||||
manager.Register(newContentProvider("gemini", "gemini-2.5-pro", true, generate))
|
||||
manager.Register(newContentProvider("openai", "gpt-5.4", true, generate))
|
||||
|
||||
return manager
|
||||
}
|
||||
|
||||
// Generate returns the generated text from a registered provider.
|
||||
//
|
||||
// provider, _ := manager.Provider("claude")
|
||||
// text, _ := provider.Generate(ctx, "Draft a release note", map[string]any{"temperature": 0.2})
|
||||
type ProviderGenerateFunc func(context.Context, string, map[string]any) (string, error)
|
||||
|
||||
// Stream sends provider output to the callback as it arrives.
|
||||
//
|
||||
// provider, _ := manager.Provider("claude")
|
||||
// _ = provider.Stream(ctx, "Draft a release note", nil, func(token string) { core.Print(nil, token) })
|
||||
type ProviderStreamFunc func(context.Context, string, map[string]any, func(string)) error
|
||||
|
||||
type contentProvider struct {
|
||||
name string
|
||||
defaultModel string
|
||||
available bool
|
||||
generate ProviderGenerateFunc
|
||||
stream ProviderStreamFunc
|
||||
}
|
||||
|
||||
func newContentProvider(name, defaultModel string, available bool, generate ProviderGenerateFunc) *contentProvider {
|
||||
provider := &contentProvider{
|
||||
name: name,
|
||||
defaultModel: defaultModel,
|
||||
available: available,
|
||||
generate: generate,
|
||||
}
|
||||
provider.stream = func(ctx context.Context, prompt string, options map[string]any, onToken func(string)) error {
|
||||
content, err := provider.Generate(ctx, prompt, options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if onToken != nil {
|
||||
onToken(content)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return provider
|
||||
}
|
||||
|
||||
func (p *contentProvider) Generate(ctx context.Context, prompt string, options map[string]any) (string, error) {
|
||||
if p.generate == nil {
|
||||
return "", core.E("provider.generate", core.Concat("provider not configured: ", p.name), nil)
|
||||
}
|
||||
|
||||
optionsCopy := map[string]any{}
|
||||
for key, value := range options {
|
||||
optionsCopy[key] = value
|
||||
}
|
||||
if optionsCopy["provider"] == nil {
|
||||
optionsCopy["provider"] = p.name
|
||||
}
|
||||
if optionsCopy["model"] == nil && p.defaultModel != "" {
|
||||
optionsCopy["model"] = p.defaultModel
|
||||
}
|
||||
|
||||
return p.generate(ctx, prompt, optionsCopy)
|
||||
}
|
||||
|
||||
func (p *contentProvider) Stream(ctx context.Context, prompt string, options map[string]any, onToken func(string)) error {
|
||||
if p.stream == nil {
|
||||
return core.E("provider.stream", core.Concat("provider not configured: ", p.name), nil)
|
||||
}
|
||||
return p.stream(ctx, prompt, options, onToken)
|
||||
}
|
||||
|
||||
func (p *contentProvider) Name() string {
|
||||
return p.name
|
||||
}
|
||||
|
||||
func (p *contentProvider) DefaultModel() string {
|
||||
return p.defaultModel
|
||||
}
|
||||
|
||||
func (p *contentProvider) IsAvailable() bool {
|
||||
return p.available
|
||||
}
|
||||
|
||||
// Register adds or replaces a provider in the registry.
|
||||
//
|
||||
// manager.Register(newContentProvider("claude", "claude-3.7-sonnet", true, generate))
|
||||
func (m *ProviderManager) Register(provider AgenticProviderInterface) {
|
||||
if m == nil || provider == nil {
|
||||
return
|
||||
}
|
||||
if m.providers == nil {
|
||||
m.providers = make(map[string]AgenticProviderInterface)
|
||||
}
|
||||
m.providers[core.Lower(core.Trim(provider.Name()))] = provider
|
||||
}
|
||||
|
||||
// Provider returns a registered provider by name.
|
||||
//
|
||||
// provider, ok := manager.Provider("openai")
|
||||
func (m *ProviderManager) Provider(name string) (AgenticProviderInterface, bool) {
|
||||
if m == nil {
|
||||
return nil, false
|
||||
}
|
||||
provider, ok := m.providers[core.Lower(core.Trim(name))]
|
||||
return provider, ok
|
||||
}
|
||||
|
||||
// Names returns the registered provider names in deterministic order.
|
||||
//
|
||||
// core.Println(manager.Names()) // ["claude", "gemini", "openai"]
|
||||
func (m *ProviderManager) Names() []string {
|
||||
if m == nil || len(m.providers) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
names := make([]string, 0, len(m.providers))
|
||||
for name := range m.providers {
|
||||
names = append(names, name)
|
||||
}
|
||||
sort.Strings(names)
|
||||
return names
|
||||
}
|
||||
|
||||
// DefaultProvider returns the first registered provider that is available.
|
||||
//
|
||||
// provider := manager.DefaultProvider()
|
||||
func (m *ProviderManager) DefaultProvider() AgenticProviderInterface {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, name := range m.Names() {
|
||||
if provider, ok := m.Provider(name); ok && provider.IsAvailable() {
|
||||
return provider
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
45
pkg/agentic/provider_manager_test.go
Normal file
45
pkg/agentic/provider_manager_test.go
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
// SPDX-License-Identifier: EUPL-1.2
|
||||
|
||||
package agentic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestProviderManager_NewProviderManager_Good_RegistersBuiltIns(t *testing.T) {
|
||||
manager := NewProviderManager(func(context.Context, string, map[string]any) (string, error) {
|
||||
return "Draft ready", nil
|
||||
})
|
||||
|
||||
require.NotNil(t, manager)
|
||||
assert.Equal(t, []string{"claude", "gemini", "openai"}, manager.Names())
|
||||
|
||||
provider, ok := manager.Provider("claude")
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "claude", provider.Name())
|
||||
assert.Equal(t, "claude-3.7-sonnet", provider.DefaultModel())
|
||||
|
||||
text, err := provider.Generate(context.Background(), "Write a release note", nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Draft ready", text)
|
||||
}
|
||||
|
||||
func TestProviderManager_Provider_Bad_UnknownNameReturnsFalse(t *testing.T) {
|
||||
manager := NewProviderManager(nil)
|
||||
|
||||
provider, ok := manager.Provider("unknown")
|
||||
assert.False(t, ok)
|
||||
assert.Nil(t, provider)
|
||||
}
|
||||
|
||||
func TestProviderManager_ContentProvider_Ugly_NoGeneratorReturnsError(t *testing.T) {
|
||||
provider := newContentProvider("claude", "claude-3.7-sonnet", true, nil)
|
||||
|
||||
_, err := provider.Generate(context.Background(), "Draft a release note", nil)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "provider not configured")
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue