diff --git a/backend.go b/backend.go index c796d2c..b32c6b6 100644 --- a/backend.go +++ b/backend.go @@ -3,11 +3,11 @@ package rocm import ( + "fmt" "os" - "path/filepath" - "strings" "forge.lthn.ai/core/go-inference" + "forge.lthn.ai/core/go-rocm/internal/gguf" ) // rocmBackend implements inference.Backend for AMD ROCm GPUs. @@ -28,6 +28,9 @@ func (b *rocmBackend) Available() bool { } // LoadModel loads a GGUF model onto the AMD GPU via llama-server. +// Model architecture is read from GGUF metadata (replacing filename-based guessing). +// If no context length is specified, defaults to min(model_context_length, 4096) +// to prevent VRAM exhaustion on models with 128K+ native context. func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (inference.TextModel, error) { cfg := inference.ApplyLoadOpts(opts) @@ -36,26 +39,23 @@ func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (infe return nil, err } - srv, err := startServer(binary, path, cfg.GPULayers, cfg.ContextLen) + meta, err := gguf.ReadMetadata(path) + if err != nil { + return nil, fmt.Errorf("rocm: read model metadata: %w", err) + } + + ctxLen := cfg.ContextLen + if ctxLen == 0 && meta.ContextLength > 0 { + ctxLen = int(min(meta.ContextLength, 4096)) + } + + srv, err := startServer(binary, path, cfg.GPULayers, ctxLen) if err != nil { return nil, err } return &rocmModel{ srv: srv, - modelType: guessModelType(path), + modelType: meta.Architecture, }, nil } - -// guessModelType extracts a model type hint from the filename. -// Hyphens are stripped so that "Llama-3" matches "llama3". -func guessModelType(path string) string { - name := strings.ToLower(filepath.Base(path)) - name = strings.ReplaceAll(name, "-", "") - for _, arch := range []string{"gemma3", "gemma2", "gemma", "qwen3", "qwen2", "qwen", "llama3", "llama", "mistral", "phi"} { - if strings.Contains(name, arch) { - return arch - } - } - return "unknown" -} diff --git a/server_test.go b/server_test.go index a3592fe..5d73045 100644 --- a/server_test.go +++ b/server_test.go @@ -81,23 +81,6 @@ func TestAvailable(t *testing.T) { assert.True(t, b.Available()) } -func TestGuessModelType(t *testing.T) { - tests := []struct { - path string - expected string - }{ - {"/data/lem/gguf/LEK-Gemma3-4B-Q4_K_M.gguf", "gemma3"}, - {"/data/models/Qwen3-8B-Q4_K_M.gguf", "qwen3"}, - {"/data/models/Llama-3.1-8B-Q4_K_M.gguf", "llama3"}, - {"/data/models/Mistral-7B-v0.3-Q4_K_M.gguf", "mistral"}, - {"/data/models/random-model.gguf", "unknown"}, - } - for _, tt := range tests { - t.Run(tt.expected, func(t *testing.T) { - assert.Equal(t, tt.expected, guessModelType(tt.path)) - }) - } -} func TestServerAlive_Running(t *testing.T) { s := &server{exited: make(chan struct{})}