diff --git a/backend.go b/backend.go index 12b3fad..8f1bdcf 100644 --- a/backend.go +++ b/backend.go @@ -3,10 +3,10 @@ package rocm import ( - "fmt" "os" "strings" "forge.lthn.ai/core/go-inference" + coreerr "forge.lthn.ai/core/go-log" "forge.lthn.ai/core/go-rocm/internal/gguf" ) @@ -42,7 +42,7 @@ func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (infe meta, err := gguf.ReadMetadata(path) if err != nil { - return nil, fmt.Errorf("rocm: read model metadata: %w", err) + return nil, coreerr.E("rocm.LoadModel", "read model metadata", err) } ctxLen := cfg.ContextLen diff --git a/go.mod b/go.mod index 96f04ec..790d763 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,10 @@ go 1.26.0 require forge.lthn.ai/core/go-inference v0.0.0 -require github.com/kr/text v0.2.0 // indirect +require ( + forge.lthn.ai/core/go-log v0.0.4 + github.com/kr/text v0.2.0 // indirect +) require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect diff --git a/go.sum b/go.sum index 072f3cd..ffc8d33 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +forge.lthn.ai/core/go-log v0.0.4 h1:KTuCEPgFmuM8KJfnyQ8vPOU1Jg654W74h8IJvfQMfv0= +forge.lthn.ai/core/go-log v0.0.4/go.mod h1:r14MXKOD3LF/sI8XUJQhRk/SZHBE7jAFVuCfgkXoZPw= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/internal/gguf/gguf.go b/internal/gguf/gguf.go index 0c908c8..28a290e 100644 --- a/internal/gguf/gguf.go +++ b/internal/gguf/gguf.go @@ -15,6 +15,8 @@ import ( "math" "os" "strings" + + coreerr "forge.lthn.ai/core/go-log" ) // ggufMagic is the GGUF file magic number: "GGUF" in little-endian. @@ -96,37 +98,37 @@ func ReadMetadata(path string) (Metadata, error) { // Read and validate magic number. 
var magic uint32 if err := binary.Read(r, binary.LittleEndian, &magic); err != nil { - return Metadata{}, fmt.Errorf("reading magic: %w", err) + return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading magic", err) } if magic != ggufMagic { - return Metadata{}, fmt.Errorf("invalid magic: 0x%08X (expected 0x%08X)", magic, ggufMagic) + return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("invalid magic: 0x%08X (expected 0x%08X)", magic, ggufMagic), nil) } // Read version. var version uint32 if err := binary.Read(r, binary.LittleEndian, &version); err != nil { - return Metadata{}, fmt.Errorf("reading version: %w", err) + return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading version", err) } if version < 2 || version > 3 { - return Metadata{}, fmt.Errorf("unsupported GGUF version: %d", version) + return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("unsupported GGUF version: %d", version), nil) } // Read tensor count and KV count. v3 uses uint64, v2 uses uint32. var tensorCount, kvCount uint64 if version == 3 { if err := binary.Read(r, binary.LittleEndian, &tensorCount); err != nil { - return Metadata{}, fmt.Errorf("reading tensor count: %w", err) + return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading tensor count", err) } if err := binary.Read(r, binary.LittleEndian, &kvCount); err != nil { - return Metadata{}, fmt.Errorf("reading kv count: %w", err) + return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading kv count", err) } } else { var tc, kc uint32 if err := binary.Read(r, binary.LittleEndian, &tc); err != nil { - return Metadata{}, fmt.Errorf("reading tensor count: %w", err) + return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading tensor count", err) } if err := binary.Read(r, binary.LittleEndian, &kc); err != nil { - return Metadata{}, fmt.Errorf("reading kv count: %w", err) + return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading kv count", err) } tensorCount = uint64(tc) kvCount = uint64(kc) @@ -148,12 +150,12 @@ func 
ReadMetadata(path string) (Metadata, error) { for i := uint64(0); i < kvCount; i++ { key, err := readString(r) if err != nil { - return Metadata{}, fmt.Errorf("reading key %d: %w", i, err) + return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading key %d", i), err) } var valType uint32 if err := binary.Read(r, binary.LittleEndian, &valType); err != nil { - return Metadata{}, fmt.Errorf("reading value type for key %q: %w", key, err) + return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value type for key %q", key), err) } // Check whether this is an interesting key before reading the value. @@ -161,7 +163,7 @@ func ReadMetadata(path string) (Metadata, error) { case key == "general.architecture": v, err := readTypedValue(r, valType) if err != nil { - return Metadata{}, fmt.Errorf("reading value for key %q: %w", key, err) + return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err) } if s, ok := v.(string); ok { meta.Architecture = s @@ -170,7 +172,7 @@ func ReadMetadata(path string) (Metadata, error) { case key == "general.name": v, err := readTypedValue(r, valType) if err != nil { - return Metadata{}, fmt.Errorf("reading value for key %q: %w", key, err) + return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err) } if s, ok := v.(string); ok { meta.Name = s @@ -179,7 +181,7 @@ func ReadMetadata(path string) (Metadata, error) { case key == "general.file_type": v, err := readTypedValue(r, valType) if err != nil { - return Metadata{}, fmt.Errorf("reading value for key %q: %w", key, err) + return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err) } if u, ok := v.(uint32); ok { meta.FileType = u @@ -188,7 +190,7 @@ func ReadMetadata(path string) (Metadata, error) { case key == "general.size_label": v, err := readTypedValue(r, valType) if err != nil { - return Metadata{}, fmt.Errorf("reading value for key %q: %w", 
key, err) + return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err) } if s, ok := v.(string); ok { meta.SizeLabel = s @@ -197,7 +199,7 @@ func ReadMetadata(path string) (Metadata, error) { case strings.HasSuffix(key, ".context_length"): v, err := readTypedValue(r, valType) if err != nil { - return Metadata{}, fmt.Errorf("reading value for key %q: %w", key, err) + return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err) } if u, ok := v.(uint32); ok { candidateContextLength[key] = u @@ -206,7 +208,7 @@ func ReadMetadata(path string) (Metadata, error) { case strings.HasSuffix(key, ".block_count"): v, err := readTypedValue(r, valType) if err != nil { - return Metadata{}, fmt.Errorf("reading value for key %q: %w", key, err) + return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err) } if u, ok := v.(uint32); ok { candidateBlockCount[key] = u @@ -215,7 +217,7 @@ func ReadMetadata(path string) (Metadata, error) { default: // Skip uninteresting value. 
if err := skipValue(r, valType); err != nil { - return Metadata{}, fmt.Errorf("skipping value for key %q: %w", key, err) + return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("skipping value for key %q", key), err) } } } @@ -245,7 +247,7 @@ func readString(r io.Reader) (string, error) { return "", err } if length > maxStringLength { - return "", fmt.Errorf("string length %d exceeds maximum %d", length, maxStringLength) + return "", coreerr.E("gguf.readString", fmt.Sprintf("string length %d exceeds maximum %d", length, maxStringLength), nil) } buf := make([]byte, length) if _, err := io.ReadFull(r, buf); err != nil { @@ -302,7 +304,7 @@ func skipValue(r io.Reader, valType uint32) error { return err } if length > maxStringLength { - return fmt.Errorf("string length %d exceeds maximum %d", length, maxStringLength) + return coreerr.E("gguf.skipValue", fmt.Sprintf("string length %d exceeds maximum %d", length, maxStringLength), nil) } _, err := readN(r, int64(length)) return err @@ -322,7 +324,7 @@ func skipValue(r io.Reader, valType uint32) error { } return nil default: - return fmt.Errorf("unknown GGUF value type: %d", valType) + return coreerr.E("gguf.skipValue", fmt.Sprintf("unknown GGUF value type: %d", valType), nil) } } diff --git a/internal/llamacpp/client.go b/internal/llamacpp/client.go index bc2577d..2fb0c11 100644 --- a/internal/llamacpp/client.go +++ b/internal/llamacpp/client.go @@ -11,6 +11,8 @@ import ( "net/http" "strings" "sync" + + coreerr "forge.lthn.ai/core/go-log" ) // ChatMessage is a single message in a conversation. 
@@ -65,26 +67,26 @@ func (c *Client) ChatComplete(ctx context.Context, req ChatRequest) (iter.Seq[st body, err := json.Marshal(req) if err != nil { - return noChunks, func() error { return fmt.Errorf("llamacpp: marshal chat request: %w", err) } + return noChunks, func() error { return coreerr.E("llamacpp.ChatComplete", "marshal chat request", err) } } httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/chat/completions", bytes.NewReader(body)) if err != nil { - return noChunks, func() error { return fmt.Errorf("llamacpp: create chat request: %w", err) } + return noChunks, func() error { return coreerr.E("llamacpp.ChatComplete", "create chat request", err) } } httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Accept", "text/event-stream") resp, err := c.httpClient.Do(httpReq) if err != nil { - return noChunks, func() error { return fmt.Errorf("llamacpp: chat request: %w", err) } + return noChunks, func() error { return coreerr.E("llamacpp.ChatComplete", "chat request", err) } } if resp.StatusCode != http.StatusOK { defer resp.Body.Close() respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 256)) return noChunks, func() error { - return fmt.Errorf("llamacpp: chat returned %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody))) + return coreerr.E("llamacpp.ChatComplete", fmt.Sprintf("chat returned %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody))), nil) } } @@ -100,7 +102,7 @@ func (c *Client) ChatComplete(ctx context.Context, req ChatRequest) (iter.Seq[st for raw := range sseData { var chunk chatChunkResponse if err := json.Unmarshal([]byte(raw), &chunk); err != nil { - streamErr = fmt.Errorf("llamacpp: decode chat chunk: %w", err) + streamErr = coreerr.E("llamacpp.ChatComplete", "decode chat chunk", err) return } if len(chunk.Choices) == 0 { @@ -130,26 +132,26 @@ func (c *Client) Complete(ctx context.Context, req CompletionRequest) (iter.Seq[ body, err := json.Marshal(req) if err != nil { - 
return noChunks, func() error { return fmt.Errorf("llamacpp: marshal completion request: %w", err) } + return noChunks, func() error { return coreerr.E("llamacpp.Complete", "marshal completion request", err) } } httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/completions", bytes.NewReader(body)) if err != nil { - return noChunks, func() error { return fmt.Errorf("llamacpp: create completion request: %w", err) } + return noChunks, func() error { return coreerr.E("llamacpp.Complete", "create completion request", err) } } httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Accept", "text/event-stream") resp, err := c.httpClient.Do(httpReq) if err != nil { - return noChunks, func() error { return fmt.Errorf("llamacpp: completion request: %w", err) } + return noChunks, func() error { return coreerr.E("llamacpp.Complete", "completion request", err) } } if resp.StatusCode != http.StatusOK { defer resp.Body.Close() respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 256)) return noChunks, func() error { - return fmt.Errorf("llamacpp: completion returned %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody))) + return coreerr.E("llamacpp.Complete", fmt.Sprintf("completion returned %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody))), nil) } } @@ -165,7 +167,7 @@ func (c *Client) Complete(ctx context.Context, req CompletionRequest) (iter.Seq[ for raw := range sseData { var chunk completionChunkResponse if err := json.Unmarshal([]byte(raw), &chunk); err != nil { - streamErr = fmt.Errorf("llamacpp: decode completion chunk: %w", err) + streamErr = coreerr.E("llamacpp.Complete", "decode completion chunk", err) return } if len(chunk.Choices) == 0 { @@ -207,7 +209,7 @@ func parseSSE(r io.Reader, errOut *error) iter.Seq[string] { } } if err := scanner.Err(); err != nil { - *errOut = fmt.Errorf("llamacpp: read SSE stream: %w", err) + *errOut = coreerr.E("llamacpp.parseSSE", "read SSE stream", err) } } } diff 
--git a/internal/llamacpp/health.go b/internal/llamacpp/health.go index 12d790f..33ec57b 100644 --- a/internal/llamacpp/health.go +++ b/internal/llamacpp/health.go @@ -7,6 +7,8 @@ import ( "io" "net/http" "strings" + + coreerr "forge.lthn.ai/core/go-log" ) // Client communicates with a llama-server instance. @@ -41,14 +43,14 @@ func (c *Client) Health(ctx context.Context) error { if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(io.LimitReader(resp.Body, 256)) - return fmt.Errorf("llamacpp: health returned %d: %s", resp.StatusCode, string(body)) + return coreerr.E("llamacpp.Health", fmt.Sprintf("health returned %d: %s", resp.StatusCode, string(body)), nil) } var h healthResponse if err := json.NewDecoder(resp.Body).Decode(&h); err != nil { - return fmt.Errorf("llamacpp: health decode: %w", err) + return coreerr.E("llamacpp.Health", "health decode", err) } if h.Status != "ok" { - return fmt.Errorf("llamacpp: server not ready (status: %s)", h.Status) + return coreerr.E("llamacpp.Health", fmt.Sprintf("server not ready (status: %s)", h.Status), nil) } return nil } diff --git a/model.go b/model.go index d01e08f..5ae5c47 100644 --- a/model.go +++ b/model.go @@ -10,6 +10,7 @@ import ( "sync" "time" "forge.lthn.ai/core/go-inference" + coreerr "forge.lthn.ai/core/go-log" "forge.lthn.ai/core/go-rocm/internal/llamacpp" ) @@ -148,7 +149,7 @@ func (m *rocmModel) Classify(ctx context.Context, prompts []string, opts ...infe text.WriteString(chunk) } if err := errFn(); err != nil { - return nil, fmt.Errorf("rocm: classify prompt %d: %w", i, err) + return nil, coreerr.E("rocm.Classify", fmt.Sprintf("classify prompt %d", i), err) } results[i] = inference.ClassifyResult{ @@ -194,7 +195,7 @@ func (m *rocmModel) BatchGenerate(ctx context.Context, prompts []string, opts .. 
tokens = append(tokens, inference.Token{Text: text}) } if err := errFn(); err != nil { - results[i].Err = fmt.Errorf("rocm: batch prompt %d: %w", i, err) + results[i].Err = coreerr.E("rocm.BatchGenerate", fmt.Sprintf("batch prompt %d", i), err) } results[i].Tokens = tokens totalGenerated += len(tokens) @@ -234,9 +235,9 @@ func (m *rocmModel) setServerExitErr() { m.mu.Lock() defer m.mu.Unlock() if m.srv.exitErr != nil { - m.lastErr = fmt.Errorf("rocm: server has exited: %w", m.srv.exitErr) + m.lastErr = coreerr.E("rocm.setServerExitErr", "server has exited", m.srv.exitErr) } else { - m.lastErr = fmt.Errorf("rocm: server has exited unexpectedly") + m.lastErr = coreerr.E("rocm.setServerExitErr", "server has exited unexpectedly", nil) } } diff --git a/rocm_stub.go b/rocm_stub.go index 34475e5..0947fe7 100644 --- a/rocm_stub.go +++ b/rocm_stub.go @@ -2,7 +2,7 @@ package rocm -import "fmt" +import coreerr "forge.lthn.ai/core/go-log" // ROCmAvailable reports whether ROCm GPU inference is available. // Returns false on non-Linux or non-amd64 platforms. @@ -10,5 +10,5 @@ func ROCmAvailable() bool { return false } // GetVRAMInfo is not available on non-Linux/non-amd64 platforms. 
func GetVRAMInfo() (VRAMInfo, error) { - return VRAMInfo{}, fmt.Errorf("rocm: VRAM monitoring not available on this platform") + return VRAMInfo{}, coreerr.E("rocm.GetVRAMInfo", "VRAM monitoring not available on this platform", nil) } diff --git a/server.go b/server.go index c5b8b7b..071e759 100644 --- a/server.go +++ b/server.go @@ -13,6 +13,7 @@ import ( "syscall" "time" + coreerr "forge.lthn.ai/core/go-log" "forge.lthn.ai/core/go-rocm/internal/llamacpp" ) @@ -40,13 +41,13 @@ func (s *server) alive() bool { func findLlamaServer() (string, error) { if p := os.Getenv("ROCM_LLAMA_SERVER_PATH"); p != "" { if _, err := os.Stat(p); err != nil { - return "", fmt.Errorf("llama-server not found at ROCM_LLAMA_SERVER_PATH=%s: %w", p, err) + return "", coreerr.E("rocm.findLlamaServer", "llama-server not found at ROCM_LLAMA_SERVER_PATH="+p, err) } return p, nil } p, err := exec.LookPath("llama-server") if err != nil { - return "", fmt.Errorf("llama-server not found in PATH: %w", err) + return "", coreerr.E("rocm.findLlamaServer", "llama-server not found in PATH", err) } return p, nil } @@ -55,7 +56,7 @@ func findLlamaServer() (string, error) { func freePort() (int, error) { ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { - return 0, fmt.Errorf("freePort: %w", err) + return 0, coreerr.E("rocm.freePort", "listen for free port", err) } port := ln.Addr().(*net.TCPAddr).Port ln.Close() @@ -92,7 +93,7 @@ func startServer(binary, modelPath string, gpuLayers, ctxSize, parallelSlots int for attempt := range maxAttempts { port, err := freePort() if err != nil { - return nil, fmt.Errorf("rocm: find free port: %w", err) + return nil, coreerr.E("rocm.startServer", "find free port", err) } args := []string{ @@ -112,7 +113,7 @@ func startServer(binary, modelPath string, gpuLayers, ctxSize, parallelSlots int cmd.Env = serverEnv() if err := cmd.Start(); err != nil { - return nil, fmt.Errorf("start llama-server: %w", err) + return nil, coreerr.E("rocm.startServer", "start 
llama-server", err) } s := &server{ @@ -139,15 +140,15 @@ func startServer(binary, modelPath string, gpuLayers, ctxSize, parallelSlots int select { case <-s.exited: _ = s.stop() - lastErr = fmt.Errorf("attempt %d: %w", attempt+1, err) + lastErr = coreerr.E("rocm.startServer", fmt.Sprintf("attempt %d", attempt+1), err) continue default: _ = s.stop() - return nil, fmt.Errorf("rocm: llama-server not ready: %w", err) + return nil, coreerr.E("rocm.startServer", "llama-server not ready", err) } } - return nil, fmt.Errorf("rocm: server failed after %d attempts: %w", maxAttempts, lastErr) + return nil, coreerr.E("rocm.startServer", fmt.Sprintf("server failed after %d attempts", maxAttempts), lastErr) } // waitReady polls the health endpoint until the server is ready. @@ -158,9 +159,9 @@ func (s *server) waitReady(ctx context.Context) error { for { select { case <-ctx.Done(): - return fmt.Errorf("timeout waiting for llama-server: %w", ctx.Err()) + return coreerr.E("server.waitReady", "timeout waiting for llama-server", ctx.Err()) case <-s.exited: - return fmt.Errorf("llama-server exited before becoming ready: %v", s.exitErr) + return coreerr.E("server.waitReady", "llama-server exited before becoming ready", s.exitErr) case <-ticker.C: if err := s.client.Health(ctx); err == nil { return nil @@ -184,7 +185,7 @@ func (s *server) stop() error { // Send SIGTERM for graceful shutdown. if err := s.cmd.Process.Signal(syscall.SIGTERM); err != nil { - return fmt.Errorf("sigterm llama-server: %w", err) + return coreerr.E("server.stop", "sigterm llama-server", err) } // Wait up to 5 seconds for clean exit. @@ -194,7 +195,7 @@ func (s *server) stop() error { case <-time.After(5 * time.Second): // Force kill. 
if err := s.cmd.Process.Kill(); err != nil { - return fmt.Errorf("kill llama-server: %w", err) + return coreerr.E("server.stop", "kill llama-server", err) } <-s.exited return s.exitErr diff --git a/vram.go b/vram.go index 954ca9a..9f6d1da 100644 --- a/vram.go +++ b/vram.go @@ -3,11 +3,12 @@ package rocm import ( - "fmt" "os" "path/filepath" "strconv" "strings" + + coreerr "forge.lthn.ai/core/go-log" ) // GetVRAMInfo reads VRAM usage for the discrete GPU from sysfs. @@ -19,10 +20,10 @@ import ( func GetVRAMInfo() (VRAMInfo, error) { cards, err := filepath.Glob("/sys/class/drm/card[0-9]*/device/mem_info_vram_total") if err != nil { - return VRAMInfo{}, fmt.Errorf("rocm: glob vram sysfs: %w", err) + return VRAMInfo{}, coreerr.E("rocm.GetVRAMInfo", "glob vram sysfs", err) } if len(cards) == 0 { - return VRAMInfo{}, fmt.Errorf("rocm: no GPU VRAM info found in sysfs") + return VRAMInfo{}, coreerr.E("rocm.GetVRAMInfo", "no GPU VRAM info found in sysfs", nil) } var bestDir string @@ -40,12 +41,12 @@ func GetVRAMInfo() (VRAMInfo, error) { } if bestDir == "" { - return VRAMInfo{}, fmt.Errorf("rocm: no readable VRAM sysfs entries") + return VRAMInfo{}, coreerr.E("rocm.GetVRAMInfo", "no readable VRAM sysfs entries", nil) } used, err := readSysfsUint64(filepath.Join(bestDir, "mem_info_vram_used")) if err != nil { - return VRAMInfo{}, fmt.Errorf("rocm: read vram used: %w", err) + return VRAMInfo{}, coreerr.E("rocm.GetVRAMInfo", "read vram used", err) } free := uint64(0)