refactor: replace fmt.Errorf/errors.New with coreerr.E()
Some checks failed
Security Scan / security (push) Successful in 8s
Test / Vet & Build (push) Failing after 23s

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Snider 2026-03-16 21:08:52 +00:00
parent c0b7485129
commit 4669cc503d
10 changed files with 74 additions and 60 deletions

View file

@ -3,10 +3,10 @@
package rocm package rocm
import ( import (
"fmt"
"os" "os"
"strings" "strings"
coreerr "forge.lthn.ai/core/go-log"
"forge.lthn.ai/core/go-inference" "forge.lthn.ai/core/go-inference"
"forge.lthn.ai/core/go-rocm/internal/gguf" "forge.lthn.ai/core/go-rocm/internal/gguf"
) )
@ -42,7 +42,7 @@ func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (infe
meta, err := gguf.ReadMetadata(path) meta, err := gguf.ReadMetadata(path)
if err != nil { if err != nil {
return nil, fmt.Errorf("rocm: read model metadata: %w", err) return nil, coreerr.E("rocm.LoadModel", "read model metadata", err)
} }
ctxLen := cfg.ContextLen ctxLen := cfg.ContextLen

5
go.mod
View file

@ -4,7 +4,10 @@ go 1.26.0
require forge.lthn.ai/core/go-inference v0.0.0 require forge.lthn.ai/core/go-inference v0.0.0
require github.com/kr/text v0.2.0 // indirect require (
forge.lthn.ai/core/go-log v0.0.4
github.com/kr/text v0.2.0 // indirect
)
require ( require (
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect

2
go.sum
View file

@ -1,3 +1,5 @@
forge.lthn.ai/core/go-log v0.0.4 h1:KTuCEPgFmuM8KJfnyQ8vPOU1Jg654W74h8IJvfQMfv0=
forge.lthn.ai/core/go-log v0.0.4/go.mod h1:r14MXKOD3LF/sI8XUJQhRk/SZHBE7jAFVuCfgkXoZPw=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=

View file

@ -15,6 +15,8 @@ import (
"math" "math"
"os" "os"
"strings" "strings"
coreerr "forge.lthn.ai/core/go-log"
) )
// ggufMagic is the GGUF file magic number: "GGUF" in little-endian. // ggufMagic is the GGUF file magic number: "GGUF" in little-endian.
@ -96,37 +98,37 @@ func ReadMetadata(path string) (Metadata, error) {
// Read and validate magic number. // Read and validate magic number.
var magic uint32 var magic uint32
if err := binary.Read(r, binary.LittleEndian, &magic); err != nil { if err := binary.Read(r, binary.LittleEndian, &magic); err != nil {
return Metadata{}, fmt.Errorf("reading magic: %w", err) return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading magic", err)
} }
if magic != ggufMagic { if magic != ggufMagic {
return Metadata{}, fmt.Errorf("invalid magic: 0x%08X (expected 0x%08X)", magic, ggufMagic) return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("invalid magic: 0x%08X (expected 0x%08X)", magic, ggufMagic), nil)
} }
// Read version. // Read version.
var version uint32 var version uint32
if err := binary.Read(r, binary.LittleEndian, &version); err != nil { if err := binary.Read(r, binary.LittleEndian, &version); err != nil {
return Metadata{}, fmt.Errorf("reading version: %w", err) return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading version", err)
} }
if version < 2 || version > 3 { if version < 2 || version > 3 {
return Metadata{}, fmt.Errorf("unsupported GGUF version: %d", version) return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("unsupported GGUF version: %d", version), nil)
} }
// Read tensor count and KV count. v3 uses uint64, v2 uses uint32. // Read tensor count and KV count. v3 uses uint64, v2 uses uint32.
var tensorCount, kvCount uint64 var tensorCount, kvCount uint64
if version == 3 { if version == 3 {
if err := binary.Read(r, binary.LittleEndian, &tensorCount); err != nil { if err := binary.Read(r, binary.LittleEndian, &tensorCount); err != nil {
return Metadata{}, fmt.Errorf("reading tensor count: %w", err) return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading tensor count", err)
} }
if err := binary.Read(r, binary.LittleEndian, &kvCount); err != nil { if err := binary.Read(r, binary.LittleEndian, &kvCount); err != nil {
return Metadata{}, fmt.Errorf("reading kv count: %w", err) return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading kv count", err)
} }
} else { } else {
var tc, kc uint32 var tc, kc uint32
if err := binary.Read(r, binary.LittleEndian, &tc); err != nil { if err := binary.Read(r, binary.LittleEndian, &tc); err != nil {
return Metadata{}, fmt.Errorf("reading tensor count: %w", err) return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading tensor count", err)
} }
if err := binary.Read(r, binary.LittleEndian, &kc); err != nil { if err := binary.Read(r, binary.LittleEndian, &kc); err != nil {
return Metadata{}, fmt.Errorf("reading kv count: %w", err) return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading kv count", err)
} }
tensorCount = uint64(tc) tensorCount = uint64(tc)
kvCount = uint64(kc) kvCount = uint64(kc)
@ -148,12 +150,12 @@ func ReadMetadata(path string) (Metadata, error) {
for i := uint64(0); i < kvCount; i++ { for i := uint64(0); i < kvCount; i++ {
key, err := readString(r) key, err := readString(r)
if err != nil { if err != nil {
return Metadata{}, fmt.Errorf("reading key %d: %w", i, err) return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading key %d", i), err)
} }
var valType uint32 var valType uint32
if err := binary.Read(r, binary.LittleEndian, &valType); err != nil { if err := binary.Read(r, binary.LittleEndian, &valType); err != nil {
return Metadata{}, fmt.Errorf("reading value type for key %q: %w", key, err) return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value type for key %q", key), err)
} }
// Check whether this is an interesting key before reading the value. // Check whether this is an interesting key before reading the value.
@ -161,7 +163,7 @@ func ReadMetadata(path string) (Metadata, error) {
case key == "general.architecture": case key == "general.architecture":
v, err := readTypedValue(r, valType) v, err := readTypedValue(r, valType)
if err != nil { if err != nil {
return Metadata{}, fmt.Errorf("reading value for key %q: %w", key, err) return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err)
} }
if s, ok := v.(string); ok { if s, ok := v.(string); ok {
meta.Architecture = s meta.Architecture = s
@ -170,7 +172,7 @@ func ReadMetadata(path string) (Metadata, error) {
case key == "general.name": case key == "general.name":
v, err := readTypedValue(r, valType) v, err := readTypedValue(r, valType)
if err != nil { if err != nil {
return Metadata{}, fmt.Errorf("reading value for key %q: %w", key, err) return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err)
} }
if s, ok := v.(string); ok { if s, ok := v.(string); ok {
meta.Name = s meta.Name = s
@ -179,7 +181,7 @@ func ReadMetadata(path string) (Metadata, error) {
case key == "general.file_type": case key == "general.file_type":
v, err := readTypedValue(r, valType) v, err := readTypedValue(r, valType)
if err != nil { if err != nil {
return Metadata{}, fmt.Errorf("reading value for key %q: %w", key, err) return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err)
} }
if u, ok := v.(uint32); ok { if u, ok := v.(uint32); ok {
meta.FileType = u meta.FileType = u
@ -188,7 +190,7 @@ func ReadMetadata(path string) (Metadata, error) {
case key == "general.size_label": case key == "general.size_label":
v, err := readTypedValue(r, valType) v, err := readTypedValue(r, valType)
if err != nil { if err != nil {
return Metadata{}, fmt.Errorf("reading value for key %q: %w", key, err) return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err)
} }
if s, ok := v.(string); ok { if s, ok := v.(string); ok {
meta.SizeLabel = s meta.SizeLabel = s
@ -197,7 +199,7 @@ func ReadMetadata(path string) (Metadata, error) {
case strings.HasSuffix(key, ".context_length"): case strings.HasSuffix(key, ".context_length"):
v, err := readTypedValue(r, valType) v, err := readTypedValue(r, valType)
if err != nil { if err != nil {
return Metadata{}, fmt.Errorf("reading value for key %q: %w", key, err) return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err)
} }
if u, ok := v.(uint32); ok { if u, ok := v.(uint32); ok {
candidateContextLength[key] = u candidateContextLength[key] = u
@ -206,7 +208,7 @@ func ReadMetadata(path string) (Metadata, error) {
case strings.HasSuffix(key, ".block_count"): case strings.HasSuffix(key, ".block_count"):
v, err := readTypedValue(r, valType) v, err := readTypedValue(r, valType)
if err != nil { if err != nil {
return Metadata{}, fmt.Errorf("reading value for key %q: %w", key, err) return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err)
} }
if u, ok := v.(uint32); ok { if u, ok := v.(uint32); ok {
candidateBlockCount[key] = u candidateBlockCount[key] = u
@ -215,7 +217,7 @@ func ReadMetadata(path string) (Metadata, error) {
default: default:
// Skip uninteresting value. // Skip uninteresting value.
if err := skipValue(r, valType); err != nil { if err := skipValue(r, valType); err != nil {
return Metadata{}, fmt.Errorf("skipping value for key %q: %w", key, err) return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("skipping value for key %q", key), err)
} }
} }
} }
@ -245,7 +247,7 @@ func readString(r io.Reader) (string, error) {
return "", err return "", err
} }
if length > maxStringLength { if length > maxStringLength {
return "", fmt.Errorf("string length %d exceeds maximum %d", length, maxStringLength) return "", coreerr.E("gguf.readString", fmt.Sprintf("string length %d exceeds maximum %d", length, maxStringLength), nil)
} }
buf := make([]byte, length) buf := make([]byte, length)
if _, err := io.ReadFull(r, buf); err != nil { if _, err := io.ReadFull(r, buf); err != nil {
@ -302,7 +304,7 @@ func skipValue(r io.Reader, valType uint32) error {
return err return err
} }
if length > maxStringLength { if length > maxStringLength {
return fmt.Errorf("string length %d exceeds maximum %d", length, maxStringLength) return coreerr.E("gguf.skipValue", fmt.Sprintf("string length %d exceeds maximum %d", length, maxStringLength), nil)
} }
_, err := readN(r, int64(length)) _, err := readN(r, int64(length))
return err return err
@ -322,7 +324,7 @@ func skipValue(r io.Reader, valType uint32) error {
} }
return nil return nil
default: default:
return fmt.Errorf("unknown GGUF value type: %d", valType) return coreerr.E("gguf.skipValue", fmt.Sprintf("unknown GGUF value type: %d", valType), nil)
} }
} }

View file

@ -11,6 +11,8 @@ import (
"net/http" "net/http"
"strings" "strings"
"sync" "sync"
coreerr "forge.lthn.ai/core/go-log"
) )
// ChatMessage is a single message in a conversation. // ChatMessage is a single message in a conversation.
@ -65,26 +67,26 @@ func (c *Client) ChatComplete(ctx context.Context, req ChatRequest) (iter.Seq[st
body, err := json.Marshal(req) body, err := json.Marshal(req)
if err != nil { if err != nil {
return noChunks, func() error { return fmt.Errorf("llamacpp: marshal chat request: %w", err) } return noChunks, func() error { return coreerr.E("llamacpp.ChatComplete", "marshal chat request", err) }
} }
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/chat/completions", bytes.NewReader(body)) httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/chat/completions", bytes.NewReader(body))
if err != nil { if err != nil {
return noChunks, func() error { return fmt.Errorf("llamacpp: create chat request: %w", err) } return noChunks, func() error { return coreerr.E("llamacpp.ChatComplete", "create chat request", err) }
} }
httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Accept", "text/event-stream") httpReq.Header.Set("Accept", "text/event-stream")
resp, err := c.httpClient.Do(httpReq) resp, err := c.httpClient.Do(httpReq)
if err != nil { if err != nil {
return noChunks, func() error { return fmt.Errorf("llamacpp: chat request: %w", err) } return noChunks, func() error { return coreerr.E("llamacpp.ChatComplete", "chat request", err) }
} }
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
defer resp.Body.Close() defer resp.Body.Close()
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 256)) respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 256))
return noChunks, func() error { return noChunks, func() error {
return fmt.Errorf("llamacpp: chat returned %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody))) return coreerr.E("llamacpp.ChatComplete", fmt.Sprintf("chat returned %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody))), nil)
} }
} }
@ -100,7 +102,7 @@ func (c *Client) ChatComplete(ctx context.Context, req ChatRequest) (iter.Seq[st
for raw := range sseData { for raw := range sseData {
var chunk chatChunkResponse var chunk chatChunkResponse
if err := json.Unmarshal([]byte(raw), &chunk); err != nil { if err := json.Unmarshal([]byte(raw), &chunk); err != nil {
streamErr = fmt.Errorf("llamacpp: decode chat chunk: %w", err) streamErr = coreerr.E("llamacpp.ChatComplete", "decode chat chunk", err)
return return
} }
if len(chunk.Choices) == 0 { if len(chunk.Choices) == 0 {
@ -130,26 +132,26 @@ func (c *Client) Complete(ctx context.Context, req CompletionRequest) (iter.Seq[
body, err := json.Marshal(req) body, err := json.Marshal(req)
if err != nil { if err != nil {
return noChunks, func() error { return fmt.Errorf("llamacpp: marshal completion request: %w", err) } return noChunks, func() error { return coreerr.E("llamacpp.Complete", "marshal completion request", err) }
} }
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/completions", bytes.NewReader(body)) httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/completions", bytes.NewReader(body))
if err != nil { if err != nil {
return noChunks, func() error { return fmt.Errorf("llamacpp: create completion request: %w", err) } return noChunks, func() error { return coreerr.E("llamacpp.Complete", "create completion request", err) }
} }
httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Accept", "text/event-stream") httpReq.Header.Set("Accept", "text/event-stream")
resp, err := c.httpClient.Do(httpReq) resp, err := c.httpClient.Do(httpReq)
if err != nil { if err != nil {
return noChunks, func() error { return fmt.Errorf("llamacpp: completion request: %w", err) } return noChunks, func() error { return coreerr.E("llamacpp.Complete", "completion request", err) }
} }
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
defer resp.Body.Close() defer resp.Body.Close()
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 256)) respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 256))
return noChunks, func() error { return noChunks, func() error {
return fmt.Errorf("llamacpp: completion returned %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody))) return coreerr.E("llamacpp.Complete", fmt.Sprintf("completion returned %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody))), nil)
} }
} }
@ -165,7 +167,7 @@ func (c *Client) Complete(ctx context.Context, req CompletionRequest) (iter.Seq[
for raw := range sseData { for raw := range sseData {
var chunk completionChunkResponse var chunk completionChunkResponse
if err := json.Unmarshal([]byte(raw), &chunk); err != nil { if err := json.Unmarshal([]byte(raw), &chunk); err != nil {
streamErr = fmt.Errorf("llamacpp: decode completion chunk: %w", err) streamErr = coreerr.E("llamacpp.Complete", "decode completion chunk", err)
return return
} }
if len(chunk.Choices) == 0 { if len(chunk.Choices) == 0 {
@ -207,7 +209,7 @@ func parseSSE(r io.Reader, errOut *error) iter.Seq[string] {
} }
} }
if err := scanner.Err(); err != nil { if err := scanner.Err(); err != nil {
*errOut = fmt.Errorf("llamacpp: read SSE stream: %w", err) *errOut = coreerr.E("llamacpp.parseSSE", "read SSE stream", err)
} }
} }
} }

View file

@ -7,6 +7,8 @@ import (
"io" "io"
"net/http" "net/http"
"strings" "strings"
coreerr "forge.lthn.ai/core/go-log"
) )
// Client communicates with a llama-server instance. // Client communicates with a llama-server instance.
@ -41,14 +43,14 @@ func (c *Client) Health(ctx context.Context) error {
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 256)) body, _ := io.ReadAll(io.LimitReader(resp.Body, 256))
return fmt.Errorf("llamacpp: health returned %d: %s", resp.StatusCode, string(body)) return coreerr.E("llamacpp.Health", fmt.Sprintf("health returned %d: %s", resp.StatusCode, string(body)), nil)
} }
var h healthResponse var h healthResponse
if err := json.NewDecoder(resp.Body).Decode(&h); err != nil { if err := json.NewDecoder(resp.Body).Decode(&h); err != nil {
return fmt.Errorf("llamacpp: health decode: %w", err) return coreerr.E("llamacpp.Health", "health decode", err)
} }
if h.Status != "ok" { if h.Status != "ok" {
return fmt.Errorf("llamacpp: server not ready (status: %s)", h.Status) return coreerr.E("llamacpp.Health", fmt.Sprintf("server not ready (status: %s)", h.Status), nil)
} }
return nil return nil
} }

View file

@ -10,6 +10,7 @@ import (
"sync" "sync"
"time" "time"
coreerr "forge.lthn.ai/core/go-log"
"forge.lthn.ai/core/go-inference" "forge.lthn.ai/core/go-inference"
"forge.lthn.ai/core/go-rocm/internal/llamacpp" "forge.lthn.ai/core/go-rocm/internal/llamacpp"
) )
@ -148,7 +149,7 @@ func (m *rocmModel) Classify(ctx context.Context, prompts []string, opts ...infe
text.WriteString(chunk) text.WriteString(chunk)
} }
if err := errFn(); err != nil { if err := errFn(); err != nil {
return nil, fmt.Errorf("rocm: classify prompt %d: %w", i, err) return nil, coreerr.E("rocm.Classify", fmt.Sprintf("classify prompt %d", i), err)
} }
results[i] = inference.ClassifyResult{ results[i] = inference.ClassifyResult{
@ -194,7 +195,7 @@ func (m *rocmModel) BatchGenerate(ctx context.Context, prompts []string, opts ..
tokens = append(tokens, inference.Token{Text: text}) tokens = append(tokens, inference.Token{Text: text})
} }
if err := errFn(); err != nil { if err := errFn(); err != nil {
results[i].Err = fmt.Errorf("rocm: batch prompt %d: %w", i, err) results[i].Err = coreerr.E("rocm.BatchGenerate", fmt.Sprintf("batch prompt %d", i), err)
} }
results[i].Tokens = tokens results[i].Tokens = tokens
totalGenerated += len(tokens) totalGenerated += len(tokens)
@ -234,9 +235,9 @@ func (m *rocmModel) setServerExitErr() {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
if m.srv.exitErr != nil { if m.srv.exitErr != nil {
m.lastErr = fmt.Errorf("rocm: server has exited: %w", m.srv.exitErr) m.lastErr = coreerr.E("rocm.setServerExitErr", "server has exited", m.srv.exitErr)
} else { } else {
m.lastErr = fmt.Errorf("rocm: server has exited unexpectedly") m.lastErr = coreerr.E("rocm.setServerExitErr", "server has exited unexpectedly", nil)
} }
} }

View file

@ -2,7 +2,7 @@
package rocm package rocm
import "fmt" import coreerr "forge.lthn.ai/core/go-log"
// ROCmAvailable reports whether ROCm GPU inference is available. // ROCmAvailable reports whether ROCm GPU inference is available.
// Returns false on non-Linux or non-amd64 platforms. // Returns false on non-Linux or non-amd64 platforms.
@ -10,5 +10,5 @@ func ROCmAvailable() bool { return false }
// GetVRAMInfo is not available on non-Linux/non-amd64 platforms. // GetVRAMInfo is not available on non-Linux/non-amd64 platforms.
func GetVRAMInfo() (VRAMInfo, error) { func GetVRAMInfo() (VRAMInfo, error) {
return VRAMInfo{}, fmt.Errorf("rocm: VRAM monitoring not available on this platform") return VRAMInfo{}, coreerr.E("rocm.GetVRAMInfo", "VRAM monitoring not available on this platform", nil)
} }

View file

@ -13,6 +13,7 @@ import (
"syscall" "syscall"
"time" "time"
coreerr "forge.lthn.ai/core/go-log"
"forge.lthn.ai/core/go-rocm/internal/llamacpp" "forge.lthn.ai/core/go-rocm/internal/llamacpp"
) )
@ -40,13 +41,13 @@ func (s *server) alive() bool {
func findLlamaServer() (string, error) { func findLlamaServer() (string, error) {
if p := os.Getenv("ROCM_LLAMA_SERVER_PATH"); p != "" { if p := os.Getenv("ROCM_LLAMA_SERVER_PATH"); p != "" {
if _, err := os.Stat(p); err != nil { if _, err := os.Stat(p); err != nil {
return "", fmt.Errorf("llama-server not found at ROCM_LLAMA_SERVER_PATH=%s: %w", p, err) return "", coreerr.E("rocm.findLlamaServer", "llama-server not found at ROCM_LLAMA_SERVER_PATH="+p, err)
} }
return p, nil return p, nil
} }
p, err := exec.LookPath("llama-server") p, err := exec.LookPath("llama-server")
if err != nil { if err != nil {
return "", fmt.Errorf("llama-server not found in PATH: %w", err) return "", coreerr.E("rocm.findLlamaServer", "llama-server not found in PATH", err)
} }
return p, nil return p, nil
} }
@ -55,7 +56,7 @@ func findLlamaServer() (string, error) {
func freePort() (int, error) { func freePort() (int, error) {
ln, err := net.Listen("tcp", "127.0.0.1:0") ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
return 0, fmt.Errorf("freePort: %w", err) return 0, coreerr.E("rocm.freePort", "listen for free port", err)
} }
port := ln.Addr().(*net.TCPAddr).Port port := ln.Addr().(*net.TCPAddr).Port
ln.Close() ln.Close()
@ -92,7 +93,7 @@ func startServer(binary, modelPath string, gpuLayers, ctxSize, parallelSlots int
for attempt := range maxAttempts { for attempt := range maxAttempts {
port, err := freePort() port, err := freePort()
if err != nil { if err != nil {
return nil, fmt.Errorf("rocm: find free port: %w", err) return nil, coreerr.E("rocm.startServer", "find free port", err)
} }
args := []string{ args := []string{
@ -112,7 +113,7 @@ func startServer(binary, modelPath string, gpuLayers, ctxSize, parallelSlots int
cmd.Env = serverEnv() cmd.Env = serverEnv()
if err := cmd.Start(); err != nil { if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("start llama-server: %w", err) return nil, coreerr.E("rocm.startServer", "start llama-server", err)
} }
s := &server{ s := &server{
@ -139,15 +140,15 @@ func startServer(binary, modelPath string, gpuLayers, ctxSize, parallelSlots int
select { select {
case <-s.exited: case <-s.exited:
_ = s.stop() _ = s.stop()
lastErr = fmt.Errorf("attempt %d: %w", attempt+1, err) lastErr = coreerr.E("rocm.startServer", fmt.Sprintf("attempt %d", attempt+1), err)
continue continue
default: default:
_ = s.stop() _ = s.stop()
return nil, fmt.Errorf("rocm: llama-server not ready: %w", err) return nil, coreerr.E("rocm.startServer", "llama-server not ready", err)
} }
} }
return nil, fmt.Errorf("rocm: server failed after %d attempts: %w", maxAttempts, lastErr) return nil, coreerr.E("rocm.startServer", fmt.Sprintf("server failed after %d attempts", maxAttempts), lastErr)
} }
// waitReady polls the health endpoint until the server is ready. // waitReady polls the health endpoint until the server is ready.
@ -158,9 +159,9 @@ func (s *server) waitReady(ctx context.Context) error {
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return fmt.Errorf("timeout waiting for llama-server: %w", ctx.Err()) return coreerr.E("server.waitReady", "timeout waiting for llama-server", ctx.Err())
case <-s.exited: case <-s.exited:
return fmt.Errorf("llama-server exited before becoming ready: %v", s.exitErr) return coreerr.E("server.waitReady", "llama-server exited before becoming ready", s.exitErr)
case <-ticker.C: case <-ticker.C:
if err := s.client.Health(ctx); err == nil { if err := s.client.Health(ctx); err == nil {
return nil return nil
@ -184,7 +185,7 @@ func (s *server) stop() error {
// Send SIGTERM for graceful shutdown. // Send SIGTERM for graceful shutdown.
if err := s.cmd.Process.Signal(syscall.SIGTERM); err != nil { if err := s.cmd.Process.Signal(syscall.SIGTERM); err != nil {
return fmt.Errorf("sigterm llama-server: %w", err) return coreerr.E("server.stop", "sigterm llama-server", err)
} }
// Wait up to 5 seconds for clean exit. // Wait up to 5 seconds for clean exit.
@ -194,7 +195,7 @@ func (s *server) stop() error {
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
// Force kill. // Force kill.
if err := s.cmd.Process.Kill(); err != nil { if err := s.cmd.Process.Kill(); err != nil {
return fmt.Errorf("kill llama-server: %w", err) return coreerr.E("server.stop", "kill llama-server", err)
} }
<-s.exited <-s.exited
return s.exitErr return s.exitErr

11
vram.go
View file

@ -3,11 +3,12 @@
package rocm package rocm
import ( import (
"fmt"
"os" "os"
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings" "strings"
coreerr "forge.lthn.ai/core/go-log"
) )
// GetVRAMInfo reads VRAM usage for the discrete GPU from sysfs. // GetVRAMInfo reads VRAM usage for the discrete GPU from sysfs.
@ -19,10 +20,10 @@ import (
func GetVRAMInfo() (VRAMInfo, error) { func GetVRAMInfo() (VRAMInfo, error) {
cards, err := filepath.Glob("/sys/class/drm/card[0-9]*/device/mem_info_vram_total") cards, err := filepath.Glob("/sys/class/drm/card[0-9]*/device/mem_info_vram_total")
if err != nil { if err != nil {
return VRAMInfo{}, fmt.Errorf("rocm: glob vram sysfs: %w", err) return VRAMInfo{}, coreerr.E("rocm.GetVRAMInfo", "glob vram sysfs", err)
} }
if len(cards) == 0 { if len(cards) == 0 {
return VRAMInfo{}, fmt.Errorf("rocm: no GPU VRAM info found in sysfs") return VRAMInfo{}, coreerr.E("rocm.GetVRAMInfo", "no GPU VRAM info found in sysfs", nil)
} }
var bestDir string var bestDir string
@ -40,12 +41,12 @@ func GetVRAMInfo() (VRAMInfo, error) {
} }
if bestDir == "" { if bestDir == "" {
return VRAMInfo{}, fmt.Errorf("rocm: no readable VRAM sysfs entries") return VRAMInfo{}, coreerr.E("rocm.GetVRAMInfo", "no readable VRAM sysfs entries", nil)
} }
used, err := readSysfsUint64(filepath.Join(bestDir, "mem_info_vram_used")) used, err := readSysfsUint64(filepath.Join(bestDir, "mem_info_vram_used"))
if err != nil { if err != nil {
return VRAMInfo{}, fmt.Errorf("rocm: read vram used: %w", err) return VRAMInfo{}, coreerr.E("rocm.GetVRAMInfo", "read vram used", err)
} }
free := uint64(0) free := uint64(0)