feat: use GGUF metadata for model type and context window auto-detection
Replaces filename-based guessModelType with GGUF header parsing. Caps default context at 4096 to prevent VRAM exhaustion on models with 128K+ native context. Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
af235653ca
commit
2c77f6f968
2 changed files with 17 additions and 34 deletions
34
backend.go
34
backend.go
|
|
@@ -3,11 +3,11 @@
|
|||
package rocm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"forge.lthn.ai/core/go-inference"
|
||||
"forge.lthn.ai/core/go-rocm/internal/gguf"
|
||||
)
|
||||
|
||||
// rocmBackend implements inference.Backend for AMD ROCm GPUs.
|
||||
|
|
@@ -28,6 +28,9 @@ func (b *rocmBackend) Available() bool {
|
|||
}
|
||||
|
||||
// LoadModel loads a GGUF model onto the AMD GPU via llama-server.
|
||||
// Model architecture is read from GGUF metadata (replacing filename-based guessing).
|
||||
// If no context length is specified, defaults to min(model_context_length, 4096)
|
||||
// to prevent VRAM exhaustion on models with 128K+ native context.
|
||||
func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (inference.TextModel, error) {
|
||||
cfg := inference.ApplyLoadOpts(opts)
|
||||
|
||||
|
|
@@ -36,26 +39,23 @@ func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (infe
|
|||
return nil, err
|
||||
}
|
||||
|
||||
srv, err := startServer(binary, path, cfg.GPULayers, cfg.ContextLen)
|
||||
meta, err := gguf.ReadMetadata(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("rocm: read model metadata: %w", err)
|
||||
}
|
||||
|
||||
ctxLen := cfg.ContextLen
|
||||
if ctxLen == 0 && meta.ContextLength > 0 {
|
||||
ctxLen = int(min(meta.ContextLength, 4096))
|
||||
}
|
||||
|
||||
srv, err := startServer(binary, path, cfg.GPULayers, ctxLen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &rocmModel{
|
||||
srv: srv,
|
||||
modelType: guessModelType(path),
|
||||
modelType: meta.Architecture,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// guessModelType derives a model architecture hint from a GGUF filename.
// The basename is lowercased and hyphens are removed, so a file named
// "Llama-3.1-8B.gguf" still matches the "llama3" token. It returns
// "unknown" when no known architecture token appears in the name.
func guessModelType(path string) string {
	// Tokens are ordered most-specific first so that, e.g., "gemma3"
	// wins over the bare "gemma" substring.
	known := []string{"gemma3", "gemma2", "gemma", "qwen3", "qwen2", "qwen", "llama3", "llama", "mistral", "phi"}

	base := strings.ReplaceAll(strings.ToLower(filepath.Base(path)), "-", "")
	for _, token := range known {
		if strings.Contains(base, token) {
			return token
		}
	}
	return "unknown"
}
|
||||
|
|
|
|||
|
|
@@ -81,23 +81,6 @@ func TestAvailable(t *testing.T) {
|
|||
assert.True(t, b.Available())
|
||||
}
|
||||
|
||||
func TestGuessModelType(t *testing.T) {
|
||||
tests := []struct {
|
||||
path string
|
||||
expected string
|
||||
}{
|
||||
{"/data/lem/gguf/LEK-Gemma3-4B-Q4_K_M.gguf", "gemma3"},
|
||||
{"/data/models/Qwen3-8B-Q4_K_M.gguf", "qwen3"},
|
||||
{"/data/models/Llama-3.1-8B-Q4_K_M.gguf", "llama3"},
|
||||
{"/data/models/Mistral-7B-v0.3-Q4_K_M.gguf", "mistral"},
|
||||
{"/data/models/random-model.gguf", "unknown"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.expected, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, guessModelType(tt.path))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerAlive_Running(t *testing.T) {
|
||||
s := &server{exited: make(chan struct{})}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue