Replace fmt/strings/path/filepath/encoding/json with core equivalents throughout
all packages. Rename cfg→configuration, srv→server/subprocess, ftName→fileTypeName,
ctxSize→contextSize. Add usage-example doc-comments to every exported symbol.
Update all test names to TestSubject_Function_{Good,Bad,Ugly} convention.
Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
303 lines
8.4 KiB
Go
303 lines
8.4 KiB
Go
//go:build linux && amd64
|
|
|
|
package rocm
|
|
|
|
import (
|
|
"context"
|
|
"iter"
|
|
"sync"
|
|
"time"
|
|
|
|
"dappco.re/go/core"
|
|
|
|
coreerr "forge.lthn.ai/core/go-log"
|
|
"forge.lthn.ai/core/go-inference"
|
|
"forge.lthn.ai/core/go-rocm/internal/llamacpp"
|
|
)
|
|
|
|
// rocmModel implements inference.TextModel using a llama-server subprocess.
|
|
type rocmModel struct {
|
|
server *server
|
|
modelType string
|
|
modelInfo inference.ModelInfo
|
|
|
|
mu sync.Mutex
|
|
lastErr error
|
|
metrics inference.GenerateMetrics
|
|
}
|
|
|
|
// Generate streams tokens for the given prompt via llama-server's /v1/completions endpoint.
|
|
//
|
|
// for tok := range m.Generate(ctx, "The capital of France is", inference.WithMaxTokens(32)) {
|
|
// output += tok.Text
|
|
// }
|
|
func (m *rocmModel) Generate(ctx context.Context, prompt string, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
|
|
m.mu.Lock()
|
|
m.lastErr = nil
|
|
m.mu.Unlock()
|
|
|
|
if !m.server.alive() {
|
|
m.setServerExitErr()
|
|
return func(yield func(inference.Token) bool) {}
|
|
}
|
|
|
|
configuration := inference.ApplyGenerateOpts(opts)
|
|
|
|
completionRequest := llamacpp.CompletionRequest{
|
|
Prompt: prompt,
|
|
MaxTokens: configuration.MaxTokens,
|
|
Temperature: configuration.Temperature,
|
|
TopK: configuration.TopK,
|
|
TopP: configuration.TopP,
|
|
RepeatPenalty: configuration.RepeatPenalty,
|
|
}
|
|
|
|
start := time.Now()
|
|
chunks, errorFunc := m.server.client.Complete(ctx, completionRequest)
|
|
|
|
return func(yield func(inference.Token) bool) {
|
|
var tokenCount int
|
|
decodeStart := time.Now()
|
|
for text := range chunks {
|
|
tokenCount++
|
|
if !yield(inference.Token{Text: text}) {
|
|
break
|
|
}
|
|
}
|
|
if err := errorFunc(); err != nil {
|
|
m.mu.Lock()
|
|
m.lastErr = err
|
|
m.mu.Unlock()
|
|
}
|
|
m.recordMetrics(0, tokenCount, start, decodeStart)
|
|
}
|
|
}
|
|
|
|
// Chat streams tokens from a multi-turn conversation via llama-server's /v1/chat/completions endpoint.
|
|
//
|
|
// msgs := []inference.Message{{Role: "user", Content: "Hello"}}
|
|
// for tok := range m.Chat(ctx, msgs, inference.WithMaxTokens(64)) {
|
|
// output += tok.Text
|
|
// }
|
|
func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
|
|
m.mu.Lock()
|
|
m.lastErr = nil
|
|
m.mu.Unlock()
|
|
|
|
if !m.server.alive() {
|
|
m.setServerExitErr()
|
|
return func(yield func(inference.Token) bool) {}
|
|
}
|
|
|
|
configuration := inference.ApplyGenerateOpts(opts)
|
|
|
|
chatMessages := make([]llamacpp.ChatMessage, len(messages))
|
|
for i, msg := range messages {
|
|
chatMessages[i] = llamacpp.ChatMessage{
|
|
Role: msg.Role,
|
|
Content: msg.Content,
|
|
}
|
|
}
|
|
|
|
chatRequest := llamacpp.ChatRequest{
|
|
Messages: chatMessages,
|
|
MaxTokens: configuration.MaxTokens,
|
|
Temperature: configuration.Temperature,
|
|
TopK: configuration.TopK,
|
|
TopP: configuration.TopP,
|
|
RepeatPenalty: configuration.RepeatPenalty,
|
|
}
|
|
|
|
start := time.Now()
|
|
chunks, errorFunc := m.server.client.ChatComplete(ctx, chatRequest)
|
|
|
|
return func(yield func(inference.Token) bool) {
|
|
var tokenCount int
|
|
decodeStart := time.Now()
|
|
for text := range chunks {
|
|
tokenCount++
|
|
if !yield(inference.Token{Text: text}) {
|
|
break
|
|
}
|
|
}
|
|
if err := errorFunc(); err != nil {
|
|
m.mu.Lock()
|
|
m.lastErr = err
|
|
m.mu.Unlock()
|
|
}
|
|
m.recordMetrics(0, tokenCount, start, decodeStart)
|
|
}
|
|
}
|
|
|
|
// Classify runs batched prefill-only inference via llama-server.
|
|
// Each prompt gets a single-token completion (max_tokens=1, temperature=0).
|
|
// llama-server has no native classify endpoint, so this simulates it.
|
|
//
|
|
// results, err := m.Classify(ctx, []string{"positive review", "negative review"})
|
|
// // results[0].Token.Text == "pos" or similar top token
|
|
func (m *rocmModel) Classify(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
|
|
if !m.server.alive() {
|
|
m.setServerExitErr()
|
|
return nil, m.Err()
|
|
}
|
|
|
|
start := time.Now()
|
|
results := make([]inference.ClassifyResult, len(prompts))
|
|
|
|
for i, prompt := range prompts {
|
|
if ctx.Err() != nil {
|
|
return nil, ctx.Err()
|
|
}
|
|
|
|
completionRequest := llamacpp.CompletionRequest{
|
|
Prompt: prompt,
|
|
MaxTokens: 1,
|
|
Temperature: 0,
|
|
}
|
|
|
|
chunks, errorFunc := m.server.client.Complete(ctx, completionRequest)
|
|
builder := core.NewBuilder()
|
|
for chunk := range chunks {
|
|
builder.WriteString(chunk)
|
|
}
|
|
if err := errorFunc(); err != nil {
|
|
return nil, coreerr.E("rocm.Classify", core.Sprintf("classify prompt %d", i), err)
|
|
}
|
|
|
|
results[i] = inference.ClassifyResult{
|
|
Token: inference.Token{Text: builder.String()},
|
|
}
|
|
}
|
|
|
|
m.recordMetrics(len(prompts), len(prompts), start, start)
|
|
return results, nil
|
|
}
|
|
|
|
// BatchGenerate runs batched autoregressive generation via llama-server.
|
|
// Each prompt is decoded sequentially up to MaxTokens.
|
|
//
|
|
// results, err := m.BatchGenerate(ctx, []string{"prompt A", "prompt B"}, inference.WithMaxTokens(64))
|
|
// // results[0].Tokens — tokens generated for prompt A
|
|
func (m *rocmModel) BatchGenerate(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.BatchResult, error) {
|
|
if !m.server.alive() {
|
|
m.setServerExitErr()
|
|
return nil, m.Err()
|
|
}
|
|
|
|
configuration := inference.ApplyGenerateOpts(opts)
|
|
start := time.Now()
|
|
results := make([]inference.BatchResult, len(prompts))
|
|
var totalGenerated int
|
|
|
|
for i, prompt := range prompts {
|
|
if ctx.Err() != nil {
|
|
results[i].Err = ctx.Err()
|
|
continue
|
|
}
|
|
|
|
completionRequest := llamacpp.CompletionRequest{
|
|
Prompt: prompt,
|
|
MaxTokens: configuration.MaxTokens,
|
|
Temperature: configuration.Temperature,
|
|
TopK: configuration.TopK,
|
|
TopP: configuration.TopP,
|
|
RepeatPenalty: configuration.RepeatPenalty,
|
|
}
|
|
|
|
chunks, errorFunc := m.server.client.Complete(ctx, completionRequest)
|
|
var tokens []inference.Token
|
|
for text := range chunks {
|
|
tokens = append(tokens, inference.Token{Text: text})
|
|
}
|
|
if err := errorFunc(); err != nil {
|
|
results[i].Err = coreerr.E("rocm.BatchGenerate", core.Sprintf("batch prompt %d", i), err)
|
|
}
|
|
results[i].Tokens = tokens
|
|
totalGenerated += len(tokens)
|
|
}
|
|
|
|
m.recordMetrics(len(prompts), totalGenerated, start, start)
|
|
return results, nil
|
|
}
|
|
|
|
// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3", "llama3").
|
|
//
|
|
// arch := m.ModelType() // "gemma3"
|
|
func (m *rocmModel) ModelType() string { return m.modelType }
|
|
|
|
// Info returns metadata about the loaded model.
|
|
//
|
|
// info := m.Info()
|
|
// // info.Architecture == "gemma3", info.NumLayers == 26
|
|
func (m *rocmModel) Info() inference.ModelInfo { return m.modelInfo }
|
|
|
|
// Metrics returns performance metrics from the last inference operation.
|
|
//
|
|
// metrics := m.Metrics()
|
|
// // metrics.DecodeTokensPerSec, metrics.TotalDuration
|
|
func (m *rocmModel) Metrics() inference.GenerateMetrics {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
return m.metrics
|
|
}
|
|
|
|
// Err returns the error from the last Generate/Chat call, if any.
|
|
//
|
|
// for tok := range m.Generate(ctx, prompt) { }
|
|
// if err := m.Err(); err != nil { /* handle */ }
|
|
func (m *rocmModel) Err() error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
return m.lastErr
|
|
}
|
|
|
|
// Close releases the llama-server subprocess and all associated resources.
|
|
//
|
|
// m, err := backend.LoadModel("/data/model.gguf")
|
|
// defer m.Close()
|
|
func (m *rocmModel) Close() error {
|
|
return m.server.stop()
|
|
}
|
|
|
|
// setServerExitErr stores an appropriate error when the server is dead.
|
|
func (m *rocmModel) setServerExitErr() {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
if m.server.exitErr != nil {
|
|
m.lastErr = coreerr.E("rocm.setServerExitErr", "server has exited", m.server.exitErr)
|
|
} else {
|
|
m.lastErr = coreerr.E("rocm.setServerExitErr", "server has exited unexpectedly", nil)
|
|
}
|
|
}
|
|
|
|
// recordMetrics captures timing data from an inference operation.
|
|
func (m *rocmModel) recordMetrics(promptTokens, generatedTokens int, start, decodeStart time.Time) {
|
|
now := time.Now()
|
|
total := now.Sub(start)
|
|
decode := now.Sub(decodeStart)
|
|
prefill := total - decode
|
|
|
|
result := inference.GenerateMetrics{
|
|
PromptTokens: promptTokens,
|
|
GeneratedTokens: generatedTokens,
|
|
PrefillDuration: prefill,
|
|
DecodeDuration: decode,
|
|
TotalDuration: total,
|
|
}
|
|
if prefill > 0 && promptTokens > 0 {
|
|
result.PrefillTokensPerSec = float64(promptTokens) / prefill.Seconds()
|
|
}
|
|
if decode > 0 && generatedTokens > 0 {
|
|
result.DecodeTokensPerSec = float64(generatedTokens) / decode.Seconds()
|
|
}
|
|
|
|
// Try to get VRAM stats — best effort.
|
|
if vramInfo, err := GetVRAMInfo(); err == nil {
|
|
result.PeakMemoryBytes = vramInfo.Used
|
|
result.ActiveMemoryBytes = vramInfo.Used
|
|
}
|
|
|
|
m.mu.Lock()
|
|
m.metrics = result
|
|
m.mu.Unlock()
|
|
}
|