Validate chat settings and image model compatibility
This commit is contained in:
parent
f496454781
commit
a79fd7bf34
2 changed files with 218 additions and 8 deletions
|
|
@ -461,6 +461,9 @@ func (s *Service) now() time.Time {
|
|||
}
|
||||
|
||||
func (s *Service) saveSettings(settings ChatSettings) error {
|
||||
if err := s.validateSettings(settings); err != nil {
|
||||
return err
|
||||
}
|
||||
payload := core.JSONMarshalString(settings)
|
||||
return s.store.Set(settingsGroup, settingsKey, payload)
|
||||
}
|
||||
|
|
@ -479,6 +482,9 @@ func (s *Service) loadSettings() ChatSettings {
|
|||
}
|
||||
|
||||
func (s *Service) selectModel(input selectModelInput) (ChatSettings, error) {
|
||||
if err := s.validateModelName(input.Model); err != nil {
|
||||
return ChatSettings{}, err
|
||||
}
|
||||
settings := s.loadSettings()
|
||||
settings.DefaultModel = input.Model
|
||||
if err := s.saveSettings(settings); err != nil {
|
||||
|
|
@ -504,6 +510,9 @@ func (s *Service) selectModel(input selectModelInput) (ChatSettings, error) {
|
|||
}
|
||||
|
||||
func (s *Service) saveConversation(conv Conversation) (Conversation, error) {
|
||||
if err := s.validateConversation(conv); err != nil {
|
||||
return Conversation{}, err
|
||||
}
|
||||
if conv.CreatedAt.IsZero() {
|
||||
conv.CreatedAt = s.now()
|
||||
}
|
||||
|
|
@ -831,6 +840,9 @@ func (s *Service) send(ctx context.Context, input sendInput) (Conversation, erro
|
|||
for toolRound := 0; toolRound < 3; toolRound++ {
|
||||
effectiveSettings := s.mergedSettings(settings, conv.Settings)
|
||||
conv.Model = s.resolveModel(conv.Model, effectiveSettings.DefaultModel)
|
||||
if err := s.validateAttachmentsForModel(conv.Model, attachmentsForConversationTurn(conv.Messages)); err != nil {
|
||||
return conv, err
|
||||
}
|
||||
|
||||
assistantMessage, err := s.streamAssistant(ctx, conv, effectiveSettings)
|
||||
if err != nil {
|
||||
|
|
@ -1119,6 +1131,131 @@ func (s *Service) discoverModels() []ModelEntry {
|
|||
return results
|
||||
}
|
||||
|
||||
func (s *Service) validateSettings(settings ChatSettings) error {
|
||||
if settings.Temperature < 0 || settings.Temperature > 2 {
|
||||
return coreerr.E("chat.settings.save", "temperature must be between 0.0 and 2.0", nil)
|
||||
}
|
||||
if settings.TopP < 0 || settings.TopP > 1 {
|
||||
return coreerr.E("chat.settings.save", "top_p must be between 0.0 and 1.0", nil)
|
||||
}
|
||||
if settings.TopK < 0 || settings.TopK > 200 {
|
||||
return coreerr.E("chat.settings.save", "top_k must be between 0 and 200", nil)
|
||||
}
|
||||
if settings.MaxTokens < 64 || settings.MaxTokens > 32768 {
|
||||
return coreerr.E("chat.settings.save", "max_tokens must be between 64 and 32768", nil)
|
||||
}
|
||||
if !validContextWindow(settings.ContextWindow) {
|
||||
return coreerr.E("chat.settings.save", "context_window must be one of 2048, 4096, 8192, 16384, or 32768", nil)
|
||||
}
|
||||
if err := s.validateOptionalModelName(settings.DefaultModel); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validContextWindow(value int) bool {
|
||||
switch value {
|
||||
case 2048, 4096, 8192, 16384, 32768:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) validateConversation(conv Conversation) error {
|
||||
if strings.TrimSpace(conv.ID) == "" {
|
||||
return coreerr.E("chat.saveConversation", "conversation id is required", nil)
|
||||
}
|
||||
if err := s.validateOptionalModelName(conv.Model); err != nil {
|
||||
return err
|
||||
}
|
||||
if conv.Settings != nil {
|
||||
if err := s.validateSettings(*conv.Settings); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for _, message := range conv.Messages {
|
||||
if err := validateMessageAttachments(message); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := s.validateAttachmentsForModel(s.resolveModel(conv.Model, s.loadSettings().DefaultModel), attachmentsForConversationTurn(conv.Messages)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) validateModelName(name string) error {
|
||||
if strings.TrimSpace(name) == "" {
|
||||
return coreerr.E("chat.selectModel", "model is required", nil)
|
||||
}
|
||||
if len(s.discoverModels()) == 0 {
|
||||
return nil
|
||||
}
|
||||
if _, ok := s.findModel(name); ok {
|
||||
return nil
|
||||
}
|
||||
return coreerr.E("chat.selectModel", "model is not available: "+name, nil)
|
||||
}
|
||||
|
||||
func (s *Service) validateOptionalModelName(name string) error {
|
||||
if strings.TrimSpace(name) == "" {
|
||||
return nil
|
||||
}
|
||||
if len(s.discoverModels()) == 0 || strings.EqualFold(strings.TrimSpace(name), "default") {
|
||||
return nil
|
||||
}
|
||||
if _, ok := s.findModel(name); ok {
|
||||
return nil
|
||||
}
|
||||
return coreerr.E("chat.model", "model is not available: "+name, nil)
|
||||
}
|
||||
|
||||
func (s *Service) findModel(name string) (ModelEntry, bool) {
|
||||
for _, model := range s.discoverModels() {
|
||||
if strings.EqualFold(model.Name, strings.TrimSpace(name)) {
|
||||
return model, true
|
||||
}
|
||||
}
|
||||
return ModelEntry{}, false
|
||||
}
|
||||
|
||||
func (s *Service) validateAttachmentsForModel(modelName string, attachments []ImageAttachment) error {
|
||||
if len(attachments) == 0 {
|
||||
return nil
|
||||
}
|
||||
model, ok := s.findModel(modelName)
|
||||
if !ok {
|
||||
return coreerr.E("chat.send", "image attachments require a discovered vision-capable model", nil)
|
||||
}
|
||||
if !model.SupportsVision {
|
||||
return coreerr.E("chat.send", "selected model does not support image input: "+model.Name, nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateMessageAttachments(message ChatMessage) error {
|
||||
for _, attachment := range message.Attachments {
|
||||
if err := validateImageAttachment(attachment); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func attachmentsForConversationTurn(messages []ChatMessage) []ImageAttachment {
|
||||
if len(messages) == 0 {
|
||||
return nil
|
||||
}
|
||||
for index := len(messages) - 1; index >= 0; index-- {
|
||||
if messages[index].Role != "user" {
|
||||
continue
|
||||
}
|
||||
return messages[index].Attachments
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func discoverModelsOnDisk(root string) []ModelEntry {
|
||||
if strings.TrimSpace(root) == "" {
|
||||
return nil
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import (
|
|||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
|
@ -37,25 +38,39 @@ func (m *mockToolExecutor) CallTool(_ context.Context, name string, arguments ma
|
|||
return `{"mode":"left-right"}`, nil
|
||||
}
|
||||
|
||||
func newChatCore(t *testing.T, handler http.HandlerFunc, toolExecutor ToolExecutor) *core.Core {
|
||||
func newChatCore(t *testing.T, handler http.HandlerFunc, toolExecutor ToolExecutor, optionFns ...func(*Options)) *core.Core {
|
||||
t.Helper()
|
||||
server := httptest.NewServer(handler)
|
||||
t.Cleanup(server.Close)
|
||||
|
||||
options := []func(*Options){
|
||||
func(o *Options) { o.APIURL = server.URL },
|
||||
func(o *Options) { o.StorePath = filepath.Join(t.TempDir(), "chat.db") },
|
||||
func(o *Options) { o.ToolExecutor = toolExecutor },
|
||||
func(o *Options) { o.Now = func() time.Time { return time.Unix(1_700_000_000, 0).UTC() } },
|
||||
func(o *Options) { o.ModelRoots = nil },
|
||||
}
|
||||
options = append(options, optionFns...)
|
||||
|
||||
c := core.New(
|
||||
core.WithService(Register(
|
||||
func(o *Options) { o.APIURL = server.URL },
|
||||
func(o *Options) { o.StorePath = filepath.Join(t.TempDir(), "chat.db") },
|
||||
func(o *Options) { o.ToolExecutor = toolExecutor },
|
||||
func(o *Options) { o.Now = func() time.Time { return time.Unix(1_700_000_000, 0).UTC() } },
|
||||
func(o *Options) { o.ModelRoots = nil },
|
||||
)),
|
||||
core.WithService(Register(options...)),
|
||||
core.WithServiceLock(),
|
||||
)
|
||||
require.True(t, c.ServiceStartup(context.Background(), nil).OK)
|
||||
return c
|
||||
}
|
||||
|
||||
func createDiscoveredModelRoot(t *testing.T, name, architecture string) string {
|
||||
t.Helper()
|
||||
root := t.TempDir()
|
||||
modelDir := filepath.Join(root, name)
|
||||
require.NoError(t, os.MkdirAll(modelDir, 0o755))
|
||||
configJSON := `{"model_type":"` + architecture + `","quantization":{"bits":4,"group_size":32}}`
|
||||
require.NoError(t, os.WriteFile(filepath.Join(modelDir, "config.json"), []byte(configJSON), 0o644))
|
||||
require.NoError(t, os.WriteFile(filepath.Join(modelDir, "weights.safetensors"), []byte("fake"), 0o644))
|
||||
return root
|
||||
}
|
||||
|
||||
func TestService_Good_SendAndHistory(t *testing.T) {
|
||||
c := newChatCore(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
|
|
@ -155,3 +170,61 @@ func TestService_Good_SettingsDefaults(t *testing.T) {
|
|||
require.True(t, actionResult.OK)
|
||||
assert.Equal(t, DefaultSettings(), actionResult.Value.(ChatSettings))
|
||||
}
|
||||
|
||||
func TestService_Bad_SettingsRejectOutOfRangeValues(t *testing.T) {
|
||||
c := newChatCore(t, func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = io.WriteString(w, "data: [DONE]\n\n")
|
||||
}, &mockToolExecutor{})
|
||||
|
||||
result := c.Action("gui.chat.settings.save").Run(context.Background(), core.NewOptions(
|
||||
core.Option{Key: "temperature", Value: float32(2.5)},
|
||||
core.Option{Key: "top_p", Value: float32(0.95)},
|
||||
core.Option{Key: "top_k", Value: 64},
|
||||
core.Option{Key: "max_tokens", Value: 2048},
|
||||
core.Option{Key: "context_window", Value: 8192},
|
||||
core.Option{Key: "system_prompt", Value: "You are a helpful assistant."},
|
||||
))
|
||||
require.False(t, result.OK)
|
||||
require.Error(t, result.Value.(error))
|
||||
assert.Contains(t, result.Value.(error).Error(), "temperature must be between 0.0 and 2.0")
|
||||
}
|
||||
|
||||
func TestService_Bad_SelectModelRejectsUnknownModel(t *testing.T) {
|
||||
modelRoot := createDiscoveredModelRoot(t, "lemer", "gemma3")
|
||||
c := newChatCore(t, func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = io.WriteString(w, "data: [DONE]\n\n")
|
||||
}, &mockToolExecutor{}, func(o *Options) { o.ModelRoots = []string{modelRoot} })
|
||||
|
||||
result := c.Action("gui.chat.selectModel").Run(context.Background(), core.NewOptions(
|
||||
core.Option{Key: "model", Value: "missing-model"},
|
||||
))
|
||||
require.False(t, result.OK)
|
||||
require.Error(t, result.Value.(error))
|
||||
assert.Contains(t, result.Value.(error).Error(), "model is not available")
|
||||
}
|
||||
|
||||
func TestService_Bad_SendImageRejectsNonVisionModel(t *testing.T) {
|
||||
modelRoot := createDiscoveredModelRoot(t, "lemma", "qwen3")
|
||||
c := newChatCore(t, func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = io.WriteString(w, "data: [DONE]\n\n")
|
||||
}, &mockToolExecutor{}, func(o *Options) { o.ModelRoots = []string{modelRoot} })
|
||||
|
||||
attach := c.Action("gui.chat.attachImage").Run(context.Background(), core.NewOptions(
|
||||
core.Option{Key: "filename", Value: "photo.png"},
|
||||
core.Option{Key: "mime_type", Value: "image/png"},
|
||||
core.Option{Key: "data", Value: "ZmFrZQ=="},
|
||||
core.Option{Key: "width", Value: 32},
|
||||
core.Option{Key: "height", Value: 32},
|
||||
))
|
||||
require.True(t, attach.OK)
|
||||
|
||||
send := c.Action("gui.chat.send").Run(context.Background(), core.NewOptions(
|
||||
core.Option{Key: "content", Value: "Describe this image"},
|
||||
))
|
||||
require.False(t, send.OK)
|
||||
require.Error(t, send.Value.(error))
|
||||
assert.Contains(t, send.Value.(error).Error(), "does not support image input")
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue