feat: implement Classify, BatchGenerate, Info, Metrics on rocmModel

Brings rocmModel into compliance with the updated inference.TextModel
interface from go-inference.

- Classify: simulates prefill-only classification via max_tokens=1, temperature=0
- BatchGenerate: sequential autoregressive generation per prompt via /v1/completions
- Info: populates ModelInfo from GGUF metadata (architecture, layers, quant)
- Metrics: captures timing + VRAM usage via sysfs after each operation
- Refactors duplicate server-exit error handling into setServerExitErr()
- Adds timing instrumentation to existing Generate and Chat methods
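
Since the point of the change is interface compliance, a conventional compile-time assertion inside package rocm would pin rocmModel to the updated contract. This is a sketch, not part of the diff:

```go
package rocm

import "forge.lthn.ai/core/go-inference"

// Sketch only (not in this commit): the usual Go compile-time check that
// rocmModel satisfies the updated inference.TextModel interface. The build
// breaks if any of the methods added below drift from the contract.
var _ inference.TextModel = (*rocmModel)(nil)
```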

Co-Authored-By: Virgil <virgil@lethean.io>
Claude 2026-02-24 18:50:37 +00:00
parent 197c537e9f
commit b03f357f5d
GPG key ID: AF404715446AEB41
3 changed files with 194 additions and 14 deletions


@@ -5,6 +5,7 @@ package rocm
import (
"fmt"
"os"
"strings"
"forge.lthn.ai/core/go-inference"
"forge.lthn.ai/core/go-rocm/internal/gguf"
@@ -54,8 +55,43 @@ func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (infe
return nil, err
}
// Map the quantisation file type to bit width and group size.
quantBits := 0
quantGroup := 0
ftName := gguf.FileTypeName(meta.FileType)
switch {
case strings.HasPrefix(ftName, "Q4_"):
quantBits = 4
quantGroup = 32
case strings.HasPrefix(ftName, "Q5_"):
quantBits = 5
quantGroup = 32
case strings.HasPrefix(ftName, "Q8_"):
quantBits = 8
quantGroup = 32
case strings.HasPrefix(ftName, "Q2_"):
quantBits = 2
quantGroup = 16
case strings.HasPrefix(ftName, "Q3_"):
quantBits = 3
quantGroup = 32
case strings.HasPrefix(ftName, "Q6_"):
quantBits = 6
quantGroup = 64
case ftName == "F16":
quantBits = 16
case ftName == "F32":
quantBits = 32
}
return &rocmModel{
srv: srv,
modelType: meta.Architecture,
modelInfo: inference.ModelInfo{
Architecture: meta.Architecture,
NumLayers: int(meta.BlockCount),
QuantBits: quantBits,
QuantGroup: quantGroup,
},
}, nil
}
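
To make the new mapping concrete, here is a hedged sketch of the ModelInfo a Q4_* GGUF (for example Q4_K_M) would produce under the switch above; the architecture and layer count are illustrative placeholders, not values from this repository:

```go
package main

import (
	"fmt"

	"forge.lthn.ai/core/go-inference"
)

func main() {
	// Illustrative values for a hypothetical Q4_* model: the switch in
	// LoadModel maps any Q4_ file type to 4-bit weights with group size 32,
	// while Architecture and NumLayers come straight from GGUF metadata.
	info := inference.ModelInfo{
		Architecture: "llama3", // meta.Architecture (placeholder)
		NumLayers:    32,       // meta.BlockCount (placeholder)
		QuantBits:    4,        // Q4_* -> 4 bits
		QuantGroup:   32,       // Q4_* -> group size 32
	}
	fmt.Printf("%+v\n", info)
}
```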

go.sum (4)

@@ -1,7 +1,11 @@
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=

model.go (168)

@@ -6,7 +6,9 @@ import (
"context"
"fmt"
"iter"
"strings"
"sync"
"time"
"forge.lthn.ai/core/go-inference"
"forge.lthn.ai/core/go-rocm/internal/llamacpp"
@@ -16,9 +18,11 @@ import (
type rocmModel struct {
srv *server
modelType string
modelInfo inference.ModelInfo
mu sync.Mutex
lastErr error
metrics inference.GenerateMetrics
}
// Generate streams tokens for the given prompt via llama-server's /v1/completions endpoint.
@@ -28,13 +32,7 @@ func (m *rocmModel) Generate(ctx context.Context, prompt string, opts ...inferen
m.mu.Unlock()
if !m.srv.alive() {
m.mu.Lock()
if m.srv.exitErr != nil {
m.lastErr = fmt.Errorf("rocm: server has exited: %w", m.srv.exitErr)
} else {
m.lastErr = fmt.Errorf("rocm: server has exited unexpectedly")
}
m.mu.Unlock()
m.setServerExitErr()
return func(yield func(inference.Token) bool) {}
}
@@ -49,10 +47,14 @@ func (m *rocmModel) Generate(ctx context.Context, prompt string, opts ...inferen
RepeatPenalty: cfg.RepeatPenalty,
}
start := time.Now()
chunks, errFn := m.srv.client.Complete(ctx, req)
return func(yield func(inference.Token) bool) {
var count int
decodeStart := time.Now()
for text := range chunks {
count++
if !yield(inference.Token{Text: text}) {
break
}
@@ -62,6 +64,7 @@ func (m *rocmModel) Generate(ctx context.Context, prompt string, opts ...inferen
m.lastErr = err
m.mu.Unlock()
}
m.recordMetrics(0, count, start, decodeStart)
}
}
@@ -72,13 +75,7 @@ func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts
m.mu.Unlock()
if !m.srv.alive() {
m.mu.Lock()
if m.srv.exitErr != nil {
m.lastErr = fmt.Errorf("rocm: server has exited: %w", m.srv.exitErr)
} else {
m.lastErr = fmt.Errorf("rocm: server has exited unexpectedly")
}
m.mu.Unlock()
m.setServerExitErr()
return func(yield func(inference.Token) bool) {}
}
@@ -101,10 +98,14 @@ func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts
RepeatPenalty: cfg.RepeatPenalty,
}
start := time.Now()
chunks, errFn := m.srv.client.ChatComplete(ctx, req)
return func(yield func(inference.Token) bool) {
var count int
decodeStart := time.Now()
for text := range chunks {
count++
if !yield(inference.Token{Text: text}) {
break
}
@@ -114,12 +115,108 @@ func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts
m.lastErr = err
m.mu.Unlock()
}
m.recordMetrics(0, count, start, decodeStart)
}
}
// Classify runs batched prefill-only inference via llama-server.
// Each prompt gets a single-token completion (max_tokens=1, temperature=0).
// llama-server has no native classify endpoint, so this simulates it.
func (m *rocmModel) Classify(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
if !m.srv.alive() {
m.setServerExitErr()
return nil, m.Err()
}
start := time.Now()
results := make([]inference.ClassifyResult, len(prompts))
for i, prompt := range prompts {
if ctx.Err() != nil {
return nil, ctx.Err()
}
req := llamacpp.CompletionRequest{
Prompt: prompt,
MaxTokens: 1,
Temperature: 0,
}
chunks, errFn := m.srv.client.Complete(ctx, req)
var text strings.Builder
for chunk := range chunks {
text.WriteString(chunk)
}
if err := errFn(); err != nil {
return nil, fmt.Errorf("rocm: classify prompt %d: %w", i, err)
}
results[i] = inference.ClassifyResult{
Token: inference.Token{Text: text.String()},
}
}
m.recordMetrics(len(prompts), len(prompts), start, start)
return results, nil
}
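
A hedged usage sketch for the simulated classify path follows. The prompt template, the helper name, and the assumption that the updated inference.TextModel exposes Classify are ours, not claims made by this diff:

```go
package example

import (
	"context"
	"fmt"
	"strings"

	"forge.lthn.ai/core/go-inference"
)

// classifySentiment builds one prompt per review and reads back the single
// token that Classify (max_tokens=1, temperature=0) returns for each.
func classifySentiment(ctx context.Context, model inference.TextModel, reviews []string) ([]string, error) {
	prompts := make([]string, len(reviews))
	for i, r := range reviews {
		prompts[i] = fmt.Sprintf("Sentiment of %q (positive/negative): ", r)
	}
	results, err := model.Classify(ctx, prompts)
	if err != nil {
		return nil, err
	}
	labels := make([]string, len(results))
	for i, res := range results {
		labels[i] = strings.TrimSpace(res.Token.Text)
	}
	return labels, nil
}
```

Because the backend returns only the first generated token per prompt, the prompt itself has to steer the model toward a one-token label.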
// BatchGenerate runs batched autoregressive generation via llama-server.
// Each prompt is decoded sequentially up to MaxTokens.
func (m *rocmModel) BatchGenerate(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.BatchResult, error) {
if !m.srv.alive() {
m.setServerExitErr()
return nil, m.Err()
}
cfg := inference.ApplyGenerateOpts(opts)
start := time.Now()
results := make([]inference.BatchResult, len(prompts))
var totalGenerated int
for i, prompt := range prompts {
if ctx.Err() != nil {
results[i].Err = ctx.Err()
continue
}
req := llamacpp.CompletionRequest{
Prompt: prompt,
MaxTokens: cfg.MaxTokens,
Temperature: cfg.Temperature,
TopK: cfg.TopK,
TopP: cfg.TopP,
RepeatPenalty: cfg.RepeatPenalty,
}
chunks, errFn := m.srv.client.Complete(ctx, req)
var tokens []inference.Token
for text := range chunks {
tokens = append(tokens, inference.Token{Text: text})
}
if err := errFn(); err != nil {
results[i].Err = fmt.Errorf("rocm: batch prompt %d: %w", i, err)
}
results[i].Tokens = tokens
totalGenerated += len(tokens)
}
m.recordMetrics(len(prompts), totalGenerated, start, start)
return results, nil
}
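
And a similar sketch for the batch path. Note that per-prompt failures land in BatchResult.Err rather than aborting the whole batch, so each result is checked individually; the helper name and prompts are illustrative:

```go
package example

import (
	"context"
	"fmt"
	"strings"

	"forge.lthn.ai/core/go-inference"
)

// printBatch drains a BatchGenerate call as implemented above: a nil
// top-level error still allows individual prompts to have failed.
func printBatch(ctx context.Context, model inference.TextModel, prompts []string) error {
	results, err := model.BatchGenerate(ctx, prompts)
	if err != nil {
		return err // e.g. the llama-server process has exited
	}
	for i, res := range results {
		if res.Err != nil {
			fmt.Printf("prompt %d failed: %v\n", i, res.Err)
			continue
		}
		var out strings.Builder
		for _, tok := range res.Tokens {
			out.WriteString(tok.Text)
		}
		fmt.Printf("prompt %d (%d tokens): %s\n", i, len(res.Tokens), out.String())
	}
	return nil
}
```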
// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3", "llama3").
func (m *rocmModel) ModelType() string { return m.modelType }
// Info returns metadata about the loaded model.
func (m *rocmModel) Info() inference.ModelInfo { return m.modelInfo }
// Metrics returns performance metrics from the last inference operation.
func (m *rocmModel) Metrics() inference.GenerateMetrics {
m.mu.Lock()
defer m.mu.Unlock()
return m.metrics
}
// Err returns the error from the last Generate/Chat call, if any.
func (m *rocmModel) Err() error {
m.mu.Lock()
@@ -131,3 +228,46 @@ func (m *rocmModel) Err() error {
func (m *rocmModel) Close() error {
return m.srv.stop()
}
// setServerExitErr stores an appropriate error when the server is dead.
func (m *rocmModel) setServerExitErr() {
m.mu.Lock()
defer m.mu.Unlock()
if m.srv.exitErr != nil {
m.lastErr = fmt.Errorf("rocm: server has exited: %w", m.srv.exitErr)
} else {
m.lastErr = fmt.Errorf("rocm: server has exited unexpectedly")
}
}
// recordMetrics captures timing data from an inference operation.
func (m *rocmModel) recordMetrics(promptTokens, generatedTokens int, start, decodeStart time.Time) {
now := time.Now()
total := now.Sub(start)
decode := now.Sub(decodeStart)
prefill := total - decode
met := inference.GenerateMetrics{
PromptTokens: promptTokens,
GeneratedTokens: generatedTokens,
PrefillDuration: prefill,
DecodeDuration: decode,
TotalDuration: total,
}
if prefill > 0 && promptTokens > 0 {
met.PrefillTokensPerSec = float64(promptTokens) / prefill.Seconds()
}
if decode > 0 && generatedTokens > 0 {
met.DecodeTokensPerSec = float64(generatedTokens) / decode.Seconds()
}
// Try to get VRAM stats — best effort.
if vram, err := GetVRAMInfo(); err == nil {
met.PeakMemoryBytes = vram.Used
met.ActiveMemoryBytes = vram.Used
}
m.mu.Lock()
m.metrics = met
m.mu.Unlock()
}
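
Finally, a sketch of reading these metrics from the caller's side. It assumes Generate returns a range-over-func iterator as in the diff and that the updated TextModel exposes Metrics, as the commit title implies; the stream has to be drained before Metrics() reflects the run, because recordMetrics only fires once the iterator finishes:

```go
package example

import (
	"context"
	"fmt"

	"forge.lthn.ai/core/go-inference"
)

// printLastRunMetrics consumes a streaming Generate call, then reads the
// timing and VRAM figures recorded by recordMetrics above.
func printLastRunMetrics(ctx context.Context, model inference.TextModel) {
	for range model.Generate(ctx, "Explain ROCm in one sentence.") {
		// discard tokens; recordMetrics runs when this loop ends
	}
	m := model.Metrics()
	fmt.Printf("prefill=%s decode=%s (%.1f tok/s) peakVRAM=%d bytes\n",
		m.PrefillDuration, m.DecodeDuration, m.DecodeTokensPerSec, m.PeakMemoryBytes)
}
```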