feat: implement Classify, BatchGenerate, Info, Metrics on rocmModel
Brings rocmModel into compliance with the updated inference.TextModel interface from go-inference. - Classify: simulates prefill-only via max_tokens=1, temperature=0 - BatchGenerate: sequential autoregressive per prompt via /v1/completions - Info: populates ModelInfo from GGUF metadata (architecture, layers, quant) - Metrics: captures timing + VRAM usage via sysfs after each operation - Refactors duplicate server-exit error handling into setServerExitErr() - Adds timing instrumentation to existing Generate and Chat methods Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
197c537e9f
commit
b03f357f5d
3 changed files with 194 additions and 14 deletions
36
backend.go
36
backend.go
|
|
@ -5,6 +5,7 @@ package rocm
|
|||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"forge.lthn.ai/core/go-inference"
|
||||
"forge.lthn.ai/core/go-rocm/internal/gguf"
|
||||
|
|
@ -54,8 +55,43 @@ func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (infe
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// Map quantisation file type to bit width.
|
||||
quantBits := 0
|
||||
quantGroup := 0
|
||||
ftName := gguf.FileTypeName(meta.FileType)
|
||||
switch {
|
||||
case strings.HasPrefix(ftName, "Q4_"):
|
||||
quantBits = 4
|
||||
quantGroup = 32
|
||||
case strings.HasPrefix(ftName, "Q5_"):
|
||||
quantBits = 5
|
||||
quantGroup = 32
|
||||
case strings.HasPrefix(ftName, "Q8_"):
|
||||
quantBits = 8
|
||||
quantGroup = 32
|
||||
case strings.HasPrefix(ftName, "Q2_"):
|
||||
quantBits = 2
|
||||
quantGroup = 16
|
||||
case strings.HasPrefix(ftName, "Q3_"):
|
||||
quantBits = 3
|
||||
quantGroup = 32
|
||||
case strings.HasPrefix(ftName, "Q6_"):
|
||||
quantBits = 6
|
||||
quantGroup = 64
|
||||
case ftName == "F16":
|
||||
quantBits = 16
|
||||
case ftName == "F32":
|
||||
quantBits = 32
|
||||
}
|
||||
|
||||
return &rocmModel{
|
||||
srv: srv,
|
||||
modelType: meta.Architecture,
|
||||
modelInfo: inference.ModelInfo{
|
||||
Architecture: meta.Architecture,
|
||||
NumLayers: int(meta.BlockCount),
|
||||
QuantBits: quantBits,
|
||||
QuantGroup: quantGroup,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
|
|
|||
4
go.sum
4
go.sum
|
|
@ -1,7 +1,11 @@
|
|||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
|
|
|
|||
168
model.go
168
model.go
|
|
@ -6,7 +6,9 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"iter"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"forge.lthn.ai/core/go-inference"
|
||||
"forge.lthn.ai/core/go-rocm/internal/llamacpp"
|
||||
|
|
@ -16,9 +18,11 @@ import (
|
|||
// rocmModel is a text model served by a managed llama-server process.
// Static metadata (modelType, modelInfo) is set at load time and read
// without locking; mu guards the mutable lastErr and metrics fields.
type rocmModel struct {
	srv       *server             // managed llama-server process + HTTP client
	modelType string              // architecture identifier from GGUF metadata (e.g. "llama3")
	modelInfo inference.ModelInfo // static model metadata captured at load time

	mu      sync.Mutex                // guards lastErr and metrics
	lastErr error                     // error from the most recent operation, if any
	metrics inference.GenerateMetrics // timing + memory stats from the last operation
}
|
||||
|
||||
// Generate streams tokens for the given prompt via llama-server's /v1/completions endpoint.
|
||||
|
|
@ -28,13 +32,7 @@ func (m *rocmModel) Generate(ctx context.Context, prompt string, opts ...inferen
|
|||
m.mu.Unlock()
|
||||
|
||||
if !m.srv.alive() {
|
||||
m.mu.Lock()
|
||||
if m.srv.exitErr != nil {
|
||||
m.lastErr = fmt.Errorf("rocm: server has exited: %w", m.srv.exitErr)
|
||||
} else {
|
||||
m.lastErr = fmt.Errorf("rocm: server has exited unexpectedly")
|
||||
}
|
||||
m.mu.Unlock()
|
||||
m.setServerExitErr()
|
||||
return func(yield func(inference.Token) bool) {}
|
||||
}
|
||||
|
||||
|
|
@ -49,10 +47,14 @@ func (m *rocmModel) Generate(ctx context.Context, prompt string, opts ...inferen
|
|||
RepeatPenalty: cfg.RepeatPenalty,
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
chunks, errFn := m.srv.client.Complete(ctx, req)
|
||||
|
||||
return func(yield func(inference.Token) bool) {
|
||||
var count int
|
||||
decodeStart := time.Now()
|
||||
for text := range chunks {
|
||||
count++
|
||||
if !yield(inference.Token{Text: text}) {
|
||||
break
|
||||
}
|
||||
|
|
@ -62,6 +64,7 @@ func (m *rocmModel) Generate(ctx context.Context, prompt string, opts ...inferen
|
|||
m.lastErr = err
|
||||
m.mu.Unlock()
|
||||
}
|
||||
m.recordMetrics(0, count, start, decodeStart)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -72,13 +75,7 @@ func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts
|
|||
m.mu.Unlock()
|
||||
|
||||
if !m.srv.alive() {
|
||||
m.mu.Lock()
|
||||
if m.srv.exitErr != nil {
|
||||
m.lastErr = fmt.Errorf("rocm: server has exited: %w", m.srv.exitErr)
|
||||
} else {
|
||||
m.lastErr = fmt.Errorf("rocm: server has exited unexpectedly")
|
||||
}
|
||||
m.mu.Unlock()
|
||||
m.setServerExitErr()
|
||||
return func(yield func(inference.Token) bool) {}
|
||||
}
|
||||
|
||||
|
|
@ -101,10 +98,14 @@ func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts
|
|||
RepeatPenalty: cfg.RepeatPenalty,
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
chunks, errFn := m.srv.client.ChatComplete(ctx, req)
|
||||
|
||||
return func(yield func(inference.Token) bool) {
|
||||
var count int
|
||||
decodeStart := time.Now()
|
||||
for text := range chunks {
|
||||
count++
|
||||
if !yield(inference.Token{Text: text}) {
|
||||
break
|
||||
}
|
||||
|
|
@ -114,12 +115,108 @@ func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts
|
|||
m.lastErr = err
|
||||
m.mu.Unlock()
|
||||
}
|
||||
m.recordMetrics(0, count, start, decodeStart)
|
||||
}
|
||||
}
|
||||
|
||||
// Classify runs batched prefill-only inference via llama-server.
|
||||
// Each prompt gets a single-token completion (max_tokens=1, temperature=0).
|
||||
// llama-server has no native classify endpoint, so this simulates it.
|
||||
func (m *rocmModel) Classify(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
|
||||
if !m.srv.alive() {
|
||||
m.setServerExitErr()
|
||||
return nil, m.Err()
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
results := make([]inference.ClassifyResult, len(prompts))
|
||||
|
||||
for i, prompt := range prompts {
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
req := llamacpp.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
MaxTokens: 1,
|
||||
Temperature: 0,
|
||||
}
|
||||
|
||||
chunks, errFn := m.srv.client.Complete(ctx, req)
|
||||
var text strings.Builder
|
||||
for chunk := range chunks {
|
||||
text.WriteString(chunk)
|
||||
}
|
||||
if err := errFn(); err != nil {
|
||||
return nil, fmt.Errorf("rocm: classify prompt %d: %w", i, err)
|
||||
}
|
||||
|
||||
results[i] = inference.ClassifyResult{
|
||||
Token: inference.Token{Text: text.String()},
|
||||
}
|
||||
}
|
||||
|
||||
m.recordMetrics(len(prompts), len(prompts), start, start)
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// BatchGenerate runs batched autoregressive generation via llama-server.
|
||||
// Each prompt is decoded sequentially up to MaxTokens.
|
||||
func (m *rocmModel) BatchGenerate(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.BatchResult, error) {
|
||||
if !m.srv.alive() {
|
||||
m.setServerExitErr()
|
||||
return nil, m.Err()
|
||||
}
|
||||
|
||||
cfg := inference.ApplyGenerateOpts(opts)
|
||||
start := time.Now()
|
||||
results := make([]inference.BatchResult, len(prompts))
|
||||
var totalGenerated int
|
||||
|
||||
for i, prompt := range prompts {
|
||||
if ctx.Err() != nil {
|
||||
results[i].Err = ctx.Err()
|
||||
continue
|
||||
}
|
||||
|
||||
req := llamacpp.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
MaxTokens: cfg.MaxTokens,
|
||||
Temperature: cfg.Temperature,
|
||||
TopK: cfg.TopK,
|
||||
TopP: cfg.TopP,
|
||||
RepeatPenalty: cfg.RepeatPenalty,
|
||||
}
|
||||
|
||||
chunks, errFn := m.srv.client.Complete(ctx, req)
|
||||
var tokens []inference.Token
|
||||
for text := range chunks {
|
||||
tokens = append(tokens, inference.Token{Text: text})
|
||||
}
|
||||
if err := errFn(); err != nil {
|
||||
results[i].Err = fmt.Errorf("rocm: batch prompt %d: %w", i, err)
|
||||
}
|
||||
results[i].Tokens = tokens
|
||||
totalGenerated += len(tokens)
|
||||
}
|
||||
|
||||
m.recordMetrics(len(prompts), totalGenerated, start, start)
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3", "llama3").
|
||||
func (m *rocmModel) ModelType() string { return m.modelType }
|
||||
|
||||
// Info returns metadata about the loaded model.
|
||||
func (m *rocmModel) Info() inference.ModelInfo { return m.modelInfo }
|
||||
|
||||
// Metrics returns performance metrics from the last inference operation.
|
||||
func (m *rocmModel) Metrics() inference.GenerateMetrics {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.metrics
|
||||
}
|
||||
|
||||
// Err returns the error from the last Generate/Chat call, if any.
|
||||
func (m *rocmModel) Err() error {
|
||||
m.mu.Lock()
|
||||
|
|
@ -131,3 +228,46 @@ func (m *rocmModel) Err() error {
|
|||
// Close stops the underlying llama-server process and releases its
// resources, returning any error from shutdown.
func (m *rocmModel) Close() error {
	return m.srv.stop()
}
|
||||
|
||||
// setServerExitErr stores an appropriate error when the server is dead.
|
||||
func (m *rocmModel) setServerExitErr() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.srv.exitErr != nil {
|
||||
m.lastErr = fmt.Errorf("rocm: server has exited: %w", m.srv.exitErr)
|
||||
} else {
|
||||
m.lastErr = fmt.Errorf("rocm: server has exited unexpectedly")
|
||||
}
|
||||
}
|
||||
|
||||
// recordMetrics captures timing data from an inference operation.
|
||||
func (m *rocmModel) recordMetrics(promptTokens, generatedTokens int, start, decodeStart time.Time) {
|
||||
now := time.Now()
|
||||
total := now.Sub(start)
|
||||
decode := now.Sub(decodeStart)
|
||||
prefill := total - decode
|
||||
|
||||
met := inference.GenerateMetrics{
|
||||
PromptTokens: promptTokens,
|
||||
GeneratedTokens: generatedTokens,
|
||||
PrefillDuration: prefill,
|
||||
DecodeDuration: decode,
|
||||
TotalDuration: total,
|
||||
}
|
||||
if prefill > 0 && promptTokens > 0 {
|
||||
met.PrefillTokensPerSec = float64(promptTokens) / prefill.Seconds()
|
||||
}
|
||||
if decode > 0 && generatedTokens > 0 {
|
||||
met.DecodeTokensPerSec = float64(generatedTokens) / decode.Seconds()
|
||||
}
|
||||
|
||||
// Try to get VRAM stats — best effort.
|
||||
if vram, err := GetVRAMInfo(); err == nil {
|
||||
met.PeakMemoryBytes = vram.Used
|
||||
met.ActiveMemoryBytes = vram.Used
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
m.metrics = met
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue