feat: use GGUF metadata for model type and context window auto-detection

Replaces filename-based guessModelType with GGUF header parsing.
Caps default context at 4096 to prevent VRAM exhaustion on models
with 128K+ native context.

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Claude 2026-02-19 22:23:07 +00:00
parent af235653ca
commit 2c77f6f968
No known key found for this signature in database
GPG key ID: AF404715446AEB41
2 changed files with 17 additions and 34 deletions

View file

@ -3,11 +3,11 @@
package rocm
import (
"fmt"
"os"
"path/filepath"
"strings"
"forge.lthn.ai/core/go-inference"
"forge.lthn.ai/core/go-rocm/internal/gguf"
)
// rocmBackend implements inference.Backend for AMD ROCm GPUs.
@ -28,6 +28,9 @@ func (b *rocmBackend) Available() bool {
}
// LoadModel loads a GGUF model onto the AMD GPU via llama-server.
// Model architecture is read from GGUF metadata (replacing filename-based guessing).
// If no context length is specified, defaults to min(model_context_length, 4096)
// to prevent VRAM exhaustion on models with 128K+ native context.
func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (inference.TextModel, error) {
cfg := inference.ApplyLoadOpts(opts)
@ -36,26 +39,23 @@ func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (infe
return nil, err
}
srv, err := startServer(binary, path, cfg.GPULayers, cfg.ContextLen)
meta, err := gguf.ReadMetadata(path)
if err != nil {
return nil, fmt.Errorf("rocm: read model metadata: %w", err)
}
ctxLen := cfg.ContextLen
if ctxLen == 0 && meta.ContextLength > 0 {
ctxLen = int(min(meta.ContextLength, 4096))
}
srv, err := startServer(binary, path, cfg.GPULayers, ctxLen)
if err != nil {
return nil, err
}
return &rocmModel{
srv: srv,
modelType: guessModelType(path),
modelType: meta.Architecture,
}, nil
}
// guessModelType extracts a model type hint from the filename.
// Hyphens are stripped so that "Llama-3" matches "llama3".
func guessModelType(path string) string {
name := strings.ToLower(filepath.Base(path))
name = strings.ReplaceAll(name, "-", "")
for _, arch := range []string{"gemma3", "gemma2", "gemma", "qwen3", "qwen2", "qwen", "llama3", "llama", "mistral", "phi"} {
if strings.Contains(name, arch) {
return arch
}
}
return "unknown"
}

View file

@ -81,23 +81,6 @@ func TestAvailable(t *testing.T) {
assert.True(t, b.Available())
}
func TestGuessModelType(t *testing.T) {
tests := []struct {
path string
expected string
}{
{"/data/lem/gguf/LEK-Gemma3-4B-Q4_K_M.gguf", "gemma3"},
{"/data/models/Qwen3-8B-Q4_K_M.gguf", "qwen3"},
{"/data/models/Llama-3.1-8B-Q4_K_M.gguf", "llama3"},
{"/data/models/Mistral-7B-v0.3-Q4_K_M.gguf", "mistral"},
{"/data/models/random-model.gguf", "unknown"},
}
for _, tt := range tests {
t.Run(tt.expected, func(t *testing.T) {
assert.Equal(t, tt.expected, guessModelType(tt.path))
})
}
}
func TestServerAlive_Running(t *testing.T) {
s := &server{exited: make(chan struct{})}