refactor: replace fmt.Errorf/errors.New with coreerr.E()
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
c0b7485129
commit
4669cc503d
10 changed files with 74 additions and 60 deletions
|
|
@ -3,10 +3,10 @@
|
||||||
package rocm
|
package rocm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
coreerr "forge.lthn.ai/core/go-log"
|
||||||
"forge.lthn.ai/core/go-inference"
|
"forge.lthn.ai/core/go-inference"
|
||||||
"forge.lthn.ai/core/go-rocm/internal/gguf"
|
"forge.lthn.ai/core/go-rocm/internal/gguf"
|
||||||
)
|
)
|
||||||
|
|
@ -42,7 +42,7 @@ func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (infe
|
||||||
|
|
||||||
meta, err := gguf.ReadMetadata(path)
|
meta, err := gguf.ReadMetadata(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("rocm: read model metadata: %w", err)
|
return nil, coreerr.E("rocm.LoadModel", "read model metadata", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ctxLen := cfg.ContextLen
|
ctxLen := cfg.ContextLen
|
||||||
|
|
|
||||||
5
go.mod
5
go.mod
|
|
@ -4,7 +4,10 @@ go 1.26.0
|
||||||
|
|
||||||
require forge.lthn.ai/core/go-inference v0.0.0
|
require forge.lthn.ai/core/go-inference v0.0.0
|
||||||
|
|
||||||
require github.com/kr/text v0.2.0 // indirect
|
require (
|
||||||
|
forge.lthn.ai/core/go-log v0.0.4
|
||||||
|
github.com/kr/text v0.2.0 // indirect
|
||||||
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||||
|
|
|
||||||
2
go.sum
2
go.sum
|
|
@ -1,3 +1,5 @@
|
||||||
|
forge.lthn.ai/core/go-log v0.0.4 h1:KTuCEPgFmuM8KJfnyQ8vPOU1Jg654W74h8IJvfQMfv0=
|
||||||
|
forge.lthn.ai/core/go-log v0.0.4/go.mod h1:r14MXKOD3LF/sI8XUJQhRk/SZHBE7jAFVuCfgkXoZPw=
|
||||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,8 @@ import (
|
||||||
"math"
|
"math"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
coreerr "forge.lthn.ai/core/go-log"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ggufMagic is the GGUF file magic number: "GGUF" in little-endian.
|
// ggufMagic is the GGUF file magic number: "GGUF" in little-endian.
|
||||||
|
|
@ -96,37 +98,37 @@ func ReadMetadata(path string) (Metadata, error) {
|
||||||
// Read and validate magic number.
|
// Read and validate magic number.
|
||||||
var magic uint32
|
var magic uint32
|
||||||
if err := binary.Read(r, binary.LittleEndian, &magic); err != nil {
|
if err := binary.Read(r, binary.LittleEndian, &magic); err != nil {
|
||||||
return Metadata{}, fmt.Errorf("reading magic: %w", err)
|
return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading magic", err)
|
||||||
}
|
}
|
||||||
if magic != ggufMagic {
|
if magic != ggufMagic {
|
||||||
return Metadata{}, fmt.Errorf("invalid magic: 0x%08X (expected 0x%08X)", magic, ggufMagic)
|
return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("invalid magic: 0x%08X (expected 0x%08X)", magic, ggufMagic), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read version.
|
// Read version.
|
||||||
var version uint32
|
var version uint32
|
||||||
if err := binary.Read(r, binary.LittleEndian, &version); err != nil {
|
if err := binary.Read(r, binary.LittleEndian, &version); err != nil {
|
||||||
return Metadata{}, fmt.Errorf("reading version: %w", err)
|
return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading version", err)
|
||||||
}
|
}
|
||||||
if version < 2 || version > 3 {
|
if version < 2 || version > 3 {
|
||||||
return Metadata{}, fmt.Errorf("unsupported GGUF version: %d", version)
|
return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("unsupported GGUF version: %d", version), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read tensor count and KV count. v3 uses uint64, v2 uses uint32.
|
// Read tensor count and KV count. v3 uses uint64, v2 uses uint32.
|
||||||
var tensorCount, kvCount uint64
|
var tensorCount, kvCount uint64
|
||||||
if version == 3 {
|
if version == 3 {
|
||||||
if err := binary.Read(r, binary.LittleEndian, &tensorCount); err != nil {
|
if err := binary.Read(r, binary.LittleEndian, &tensorCount); err != nil {
|
||||||
return Metadata{}, fmt.Errorf("reading tensor count: %w", err)
|
return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading tensor count", err)
|
||||||
}
|
}
|
||||||
if err := binary.Read(r, binary.LittleEndian, &kvCount); err != nil {
|
if err := binary.Read(r, binary.LittleEndian, &kvCount); err != nil {
|
||||||
return Metadata{}, fmt.Errorf("reading kv count: %w", err)
|
return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading kv count", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
var tc, kc uint32
|
var tc, kc uint32
|
||||||
if err := binary.Read(r, binary.LittleEndian, &tc); err != nil {
|
if err := binary.Read(r, binary.LittleEndian, &tc); err != nil {
|
||||||
return Metadata{}, fmt.Errorf("reading tensor count: %w", err)
|
return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading tensor count", err)
|
||||||
}
|
}
|
||||||
if err := binary.Read(r, binary.LittleEndian, &kc); err != nil {
|
if err := binary.Read(r, binary.LittleEndian, &kc); err != nil {
|
||||||
return Metadata{}, fmt.Errorf("reading kv count: %w", err)
|
return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading kv count", err)
|
||||||
}
|
}
|
||||||
tensorCount = uint64(tc)
|
tensorCount = uint64(tc)
|
||||||
kvCount = uint64(kc)
|
kvCount = uint64(kc)
|
||||||
|
|
@ -148,12 +150,12 @@ func ReadMetadata(path string) (Metadata, error) {
|
||||||
for i := uint64(0); i < kvCount; i++ {
|
for i := uint64(0); i < kvCount; i++ {
|
||||||
key, err := readString(r)
|
key, err := readString(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Metadata{}, fmt.Errorf("reading key %d: %w", i, err)
|
return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading key %d", i), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var valType uint32
|
var valType uint32
|
||||||
if err := binary.Read(r, binary.LittleEndian, &valType); err != nil {
|
if err := binary.Read(r, binary.LittleEndian, &valType); err != nil {
|
||||||
return Metadata{}, fmt.Errorf("reading value type for key %q: %w", key, err)
|
return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value type for key %q", key), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check whether this is an interesting key before reading the value.
|
// Check whether this is an interesting key before reading the value.
|
||||||
|
|
@ -161,7 +163,7 @@ func ReadMetadata(path string) (Metadata, error) {
|
||||||
case key == "general.architecture":
|
case key == "general.architecture":
|
||||||
v, err := readTypedValue(r, valType)
|
v, err := readTypedValue(r, valType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Metadata{}, fmt.Errorf("reading value for key %q: %w", key, err)
|
return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err)
|
||||||
}
|
}
|
||||||
if s, ok := v.(string); ok {
|
if s, ok := v.(string); ok {
|
||||||
meta.Architecture = s
|
meta.Architecture = s
|
||||||
|
|
@ -170,7 +172,7 @@ func ReadMetadata(path string) (Metadata, error) {
|
||||||
case key == "general.name":
|
case key == "general.name":
|
||||||
v, err := readTypedValue(r, valType)
|
v, err := readTypedValue(r, valType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Metadata{}, fmt.Errorf("reading value for key %q: %w", key, err)
|
return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err)
|
||||||
}
|
}
|
||||||
if s, ok := v.(string); ok {
|
if s, ok := v.(string); ok {
|
||||||
meta.Name = s
|
meta.Name = s
|
||||||
|
|
@ -179,7 +181,7 @@ func ReadMetadata(path string) (Metadata, error) {
|
||||||
case key == "general.file_type":
|
case key == "general.file_type":
|
||||||
v, err := readTypedValue(r, valType)
|
v, err := readTypedValue(r, valType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Metadata{}, fmt.Errorf("reading value for key %q: %w", key, err)
|
return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err)
|
||||||
}
|
}
|
||||||
if u, ok := v.(uint32); ok {
|
if u, ok := v.(uint32); ok {
|
||||||
meta.FileType = u
|
meta.FileType = u
|
||||||
|
|
@ -188,7 +190,7 @@ func ReadMetadata(path string) (Metadata, error) {
|
||||||
case key == "general.size_label":
|
case key == "general.size_label":
|
||||||
v, err := readTypedValue(r, valType)
|
v, err := readTypedValue(r, valType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Metadata{}, fmt.Errorf("reading value for key %q: %w", key, err)
|
return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err)
|
||||||
}
|
}
|
||||||
if s, ok := v.(string); ok {
|
if s, ok := v.(string); ok {
|
||||||
meta.SizeLabel = s
|
meta.SizeLabel = s
|
||||||
|
|
@ -197,7 +199,7 @@ func ReadMetadata(path string) (Metadata, error) {
|
||||||
case strings.HasSuffix(key, ".context_length"):
|
case strings.HasSuffix(key, ".context_length"):
|
||||||
v, err := readTypedValue(r, valType)
|
v, err := readTypedValue(r, valType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Metadata{}, fmt.Errorf("reading value for key %q: %w", key, err)
|
return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err)
|
||||||
}
|
}
|
||||||
if u, ok := v.(uint32); ok {
|
if u, ok := v.(uint32); ok {
|
||||||
candidateContextLength[key] = u
|
candidateContextLength[key] = u
|
||||||
|
|
@ -206,7 +208,7 @@ func ReadMetadata(path string) (Metadata, error) {
|
||||||
case strings.HasSuffix(key, ".block_count"):
|
case strings.HasSuffix(key, ".block_count"):
|
||||||
v, err := readTypedValue(r, valType)
|
v, err := readTypedValue(r, valType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Metadata{}, fmt.Errorf("reading value for key %q: %w", key, err)
|
return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err)
|
||||||
}
|
}
|
||||||
if u, ok := v.(uint32); ok {
|
if u, ok := v.(uint32); ok {
|
||||||
candidateBlockCount[key] = u
|
candidateBlockCount[key] = u
|
||||||
|
|
@ -215,7 +217,7 @@ func ReadMetadata(path string) (Metadata, error) {
|
||||||
default:
|
default:
|
||||||
// Skip uninteresting value.
|
// Skip uninteresting value.
|
||||||
if err := skipValue(r, valType); err != nil {
|
if err := skipValue(r, valType); err != nil {
|
||||||
return Metadata{}, fmt.Errorf("skipping value for key %q: %w", key, err)
|
return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("skipping value for key %q", key), err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -245,7 +247,7 @@ func readString(r io.Reader) (string, error) {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
if length > maxStringLength {
|
if length > maxStringLength {
|
||||||
return "", fmt.Errorf("string length %d exceeds maximum %d", length, maxStringLength)
|
return "", coreerr.E("gguf.readString", fmt.Sprintf("string length %d exceeds maximum %d", length, maxStringLength), nil)
|
||||||
}
|
}
|
||||||
buf := make([]byte, length)
|
buf := make([]byte, length)
|
||||||
if _, err := io.ReadFull(r, buf); err != nil {
|
if _, err := io.ReadFull(r, buf); err != nil {
|
||||||
|
|
@ -302,7 +304,7 @@ func skipValue(r io.Reader, valType uint32) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if length > maxStringLength {
|
if length > maxStringLength {
|
||||||
return fmt.Errorf("string length %d exceeds maximum %d", length, maxStringLength)
|
return coreerr.E("gguf.skipValue", fmt.Sprintf("string length %d exceeds maximum %d", length, maxStringLength), nil)
|
||||||
}
|
}
|
||||||
_, err := readN(r, int64(length))
|
_, err := readN(r, int64(length))
|
||||||
return err
|
return err
|
||||||
|
|
@ -322,7 +324,7 @@ func skipValue(r io.Reader, valType uint32) error {
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("unknown GGUF value type: %d", valType)
|
return coreerr.E("gguf.skipValue", fmt.Sprintf("unknown GGUF value type: %d", valType), nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,8 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
coreerr "forge.lthn.ai/core/go-log"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ChatMessage is a single message in a conversation.
|
// ChatMessage is a single message in a conversation.
|
||||||
|
|
@ -65,26 +67,26 @@ func (c *Client) ChatComplete(ctx context.Context, req ChatRequest) (iter.Seq[st
|
||||||
|
|
||||||
body, err := json.Marshal(req)
|
body, err := json.Marshal(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return noChunks, func() error { return fmt.Errorf("llamacpp: marshal chat request: %w", err) }
|
return noChunks, func() error { return coreerr.E("llamacpp.ChatComplete", "marshal chat request", err) }
|
||||||
}
|
}
|
||||||
|
|
||||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/chat/completions", bytes.NewReader(body))
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/chat/completions", bytes.NewReader(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return noChunks, func() error { return fmt.Errorf("llamacpp: create chat request: %w", err) }
|
return noChunks, func() error { return coreerr.E("llamacpp.ChatComplete", "create chat request", err) }
|
||||||
}
|
}
|
||||||
httpReq.Header.Set("Content-Type", "application/json")
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
httpReq.Header.Set("Accept", "text/event-stream")
|
httpReq.Header.Set("Accept", "text/event-stream")
|
||||||
|
|
||||||
resp, err := c.httpClient.Do(httpReq)
|
resp, err := c.httpClient.Do(httpReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return noChunks, func() error { return fmt.Errorf("llamacpp: chat request: %w", err) }
|
return noChunks, func() error { return coreerr.E("llamacpp.ChatComplete", "chat request", err) }
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 256))
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 256))
|
||||||
return noChunks, func() error {
|
return noChunks, func() error {
|
||||||
return fmt.Errorf("llamacpp: chat returned %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody)))
|
return coreerr.E("llamacpp.ChatComplete", fmt.Sprintf("chat returned %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody))), nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -100,7 +102,7 @@ func (c *Client) ChatComplete(ctx context.Context, req ChatRequest) (iter.Seq[st
|
||||||
for raw := range sseData {
|
for raw := range sseData {
|
||||||
var chunk chatChunkResponse
|
var chunk chatChunkResponse
|
||||||
if err := json.Unmarshal([]byte(raw), &chunk); err != nil {
|
if err := json.Unmarshal([]byte(raw), &chunk); err != nil {
|
||||||
streamErr = fmt.Errorf("llamacpp: decode chat chunk: %w", err)
|
streamErr = coreerr.E("llamacpp.ChatComplete", "decode chat chunk", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(chunk.Choices) == 0 {
|
if len(chunk.Choices) == 0 {
|
||||||
|
|
@ -130,26 +132,26 @@ func (c *Client) Complete(ctx context.Context, req CompletionRequest) (iter.Seq[
|
||||||
|
|
||||||
body, err := json.Marshal(req)
|
body, err := json.Marshal(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return noChunks, func() error { return fmt.Errorf("llamacpp: marshal completion request: %w", err) }
|
return noChunks, func() error { return coreerr.E("llamacpp.Complete", "marshal completion request", err) }
|
||||||
}
|
}
|
||||||
|
|
||||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/completions", bytes.NewReader(body))
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/completions", bytes.NewReader(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return noChunks, func() error { return fmt.Errorf("llamacpp: create completion request: %w", err) }
|
return noChunks, func() error { return coreerr.E("llamacpp.Complete", "create completion request", err) }
|
||||||
}
|
}
|
||||||
httpReq.Header.Set("Content-Type", "application/json")
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
httpReq.Header.Set("Accept", "text/event-stream")
|
httpReq.Header.Set("Accept", "text/event-stream")
|
||||||
|
|
||||||
resp, err := c.httpClient.Do(httpReq)
|
resp, err := c.httpClient.Do(httpReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return noChunks, func() error { return fmt.Errorf("llamacpp: completion request: %w", err) }
|
return noChunks, func() error { return coreerr.E("llamacpp.Complete", "completion request", err) }
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 256))
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 256))
|
||||||
return noChunks, func() error {
|
return noChunks, func() error {
|
||||||
return fmt.Errorf("llamacpp: completion returned %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody)))
|
return coreerr.E("llamacpp.Complete", fmt.Sprintf("completion returned %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody))), nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -165,7 +167,7 @@ func (c *Client) Complete(ctx context.Context, req CompletionRequest) (iter.Seq[
|
||||||
for raw := range sseData {
|
for raw := range sseData {
|
||||||
var chunk completionChunkResponse
|
var chunk completionChunkResponse
|
||||||
if err := json.Unmarshal([]byte(raw), &chunk); err != nil {
|
if err := json.Unmarshal([]byte(raw), &chunk); err != nil {
|
||||||
streamErr = fmt.Errorf("llamacpp: decode completion chunk: %w", err)
|
streamErr = coreerr.E("llamacpp.Complete", "decode completion chunk", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(chunk.Choices) == 0 {
|
if len(chunk.Choices) == 0 {
|
||||||
|
|
@ -207,7 +209,7 @@ func parseSSE(r io.Reader, errOut *error) iter.Seq[string] {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := scanner.Err(); err != nil {
|
if err := scanner.Err(); err != nil {
|
||||||
*errOut = fmt.Errorf("llamacpp: read SSE stream: %w", err)
|
*errOut = coreerr.E("llamacpp.parseSSE", "read SSE stream", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,8 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
coreerr "forge.lthn.ai/core/go-log"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Client communicates with a llama-server instance.
|
// Client communicates with a llama-server instance.
|
||||||
|
|
@ -41,14 +43,14 @@ func (c *Client) Health(ctx context.Context) error {
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 256))
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 256))
|
||||||
return fmt.Errorf("llamacpp: health returned %d: %s", resp.StatusCode, string(body))
|
return coreerr.E("llamacpp.Health", fmt.Sprintf("health returned %d: %s", resp.StatusCode, string(body)), nil)
|
||||||
}
|
}
|
||||||
var h healthResponse
|
var h healthResponse
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&h); err != nil {
|
if err := json.NewDecoder(resp.Body).Decode(&h); err != nil {
|
||||||
return fmt.Errorf("llamacpp: health decode: %w", err)
|
return coreerr.E("llamacpp.Health", "health decode", err)
|
||||||
}
|
}
|
||||||
if h.Status != "ok" {
|
if h.Status != "ok" {
|
||||||
return fmt.Errorf("llamacpp: server not ready (status: %s)", h.Status)
|
return coreerr.E("llamacpp.Health", fmt.Sprintf("server not ready (status: %s)", h.Status), nil)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
9
model.go
9
model.go
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
coreerr "forge.lthn.ai/core/go-log"
|
||||||
"forge.lthn.ai/core/go-inference"
|
"forge.lthn.ai/core/go-inference"
|
||||||
"forge.lthn.ai/core/go-rocm/internal/llamacpp"
|
"forge.lthn.ai/core/go-rocm/internal/llamacpp"
|
||||||
)
|
)
|
||||||
|
|
@ -148,7 +149,7 @@ func (m *rocmModel) Classify(ctx context.Context, prompts []string, opts ...infe
|
||||||
text.WriteString(chunk)
|
text.WriteString(chunk)
|
||||||
}
|
}
|
||||||
if err := errFn(); err != nil {
|
if err := errFn(); err != nil {
|
||||||
return nil, fmt.Errorf("rocm: classify prompt %d: %w", i, err)
|
return nil, coreerr.E("rocm.Classify", fmt.Sprintf("classify prompt %d", i), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
results[i] = inference.ClassifyResult{
|
results[i] = inference.ClassifyResult{
|
||||||
|
|
@ -194,7 +195,7 @@ func (m *rocmModel) BatchGenerate(ctx context.Context, prompts []string, opts ..
|
||||||
tokens = append(tokens, inference.Token{Text: text})
|
tokens = append(tokens, inference.Token{Text: text})
|
||||||
}
|
}
|
||||||
if err := errFn(); err != nil {
|
if err := errFn(); err != nil {
|
||||||
results[i].Err = fmt.Errorf("rocm: batch prompt %d: %w", i, err)
|
results[i].Err = coreerr.E("rocm.BatchGenerate", fmt.Sprintf("batch prompt %d", i), err)
|
||||||
}
|
}
|
||||||
results[i].Tokens = tokens
|
results[i].Tokens = tokens
|
||||||
totalGenerated += len(tokens)
|
totalGenerated += len(tokens)
|
||||||
|
|
@ -234,9 +235,9 @@ func (m *rocmModel) setServerExitErr() {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
if m.srv.exitErr != nil {
|
if m.srv.exitErr != nil {
|
||||||
m.lastErr = fmt.Errorf("rocm: server has exited: %w", m.srv.exitErr)
|
m.lastErr = coreerr.E("rocm.setServerExitErr", "server has exited", m.srv.exitErr)
|
||||||
} else {
|
} else {
|
||||||
m.lastErr = fmt.Errorf("rocm: server has exited unexpectedly")
|
m.lastErr = coreerr.E("rocm.setServerExitErr", "server has exited unexpectedly", nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
package rocm
|
package rocm
|
||||||
|
|
||||||
import "fmt"
|
import coreerr "forge.lthn.ai/core/go-log"
|
||||||
|
|
||||||
// ROCmAvailable reports whether ROCm GPU inference is available.
|
// ROCmAvailable reports whether ROCm GPU inference is available.
|
||||||
// Returns false on non-Linux or non-amd64 platforms.
|
// Returns false on non-Linux or non-amd64 platforms.
|
||||||
|
|
@ -10,5 +10,5 @@ func ROCmAvailable() bool { return false }
|
||||||
|
|
||||||
// GetVRAMInfo is not available on non-Linux/non-amd64 platforms.
|
// GetVRAMInfo is not available on non-Linux/non-amd64 platforms.
|
||||||
func GetVRAMInfo() (VRAMInfo, error) {
|
func GetVRAMInfo() (VRAMInfo, error) {
|
||||||
return VRAMInfo{}, fmt.Errorf("rocm: VRAM monitoring not available on this platform")
|
return VRAMInfo{}, coreerr.E("rocm.GetVRAMInfo", "VRAM monitoring not available on this platform", nil)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
25
server.go
25
server.go
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
coreerr "forge.lthn.ai/core/go-log"
|
||||||
"forge.lthn.ai/core/go-rocm/internal/llamacpp"
|
"forge.lthn.ai/core/go-rocm/internal/llamacpp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -40,13 +41,13 @@ func (s *server) alive() bool {
|
||||||
func findLlamaServer() (string, error) {
|
func findLlamaServer() (string, error) {
|
||||||
if p := os.Getenv("ROCM_LLAMA_SERVER_PATH"); p != "" {
|
if p := os.Getenv("ROCM_LLAMA_SERVER_PATH"); p != "" {
|
||||||
if _, err := os.Stat(p); err != nil {
|
if _, err := os.Stat(p); err != nil {
|
||||||
return "", fmt.Errorf("llama-server not found at ROCM_LLAMA_SERVER_PATH=%s: %w", p, err)
|
return "", coreerr.E("rocm.findLlamaServer", "llama-server not found at ROCM_LLAMA_SERVER_PATH="+p, err)
|
||||||
}
|
}
|
||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
p, err := exec.LookPath("llama-server")
|
p, err := exec.LookPath("llama-server")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("llama-server not found in PATH: %w", err)
|
return "", coreerr.E("rocm.findLlamaServer", "llama-server not found in PATH", err)
|
||||||
}
|
}
|
||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|
@ -55,7 +56,7 @@ func findLlamaServer() (string, error) {
|
||||||
func freePort() (int, error) {
|
func freePort() (int, error) {
|
||||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("freePort: %w", err)
|
return 0, coreerr.E("rocm.freePort", "listen for free port", err)
|
||||||
}
|
}
|
||||||
port := ln.Addr().(*net.TCPAddr).Port
|
port := ln.Addr().(*net.TCPAddr).Port
|
||||||
ln.Close()
|
ln.Close()
|
||||||
|
|
@ -92,7 +93,7 @@ func startServer(binary, modelPath string, gpuLayers, ctxSize, parallelSlots int
|
||||||
for attempt := range maxAttempts {
|
for attempt := range maxAttempts {
|
||||||
port, err := freePort()
|
port, err := freePort()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("rocm: find free port: %w", err)
|
return nil, coreerr.E("rocm.startServer", "find free port", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
args := []string{
|
args := []string{
|
||||||
|
|
@ -112,7 +113,7 @@ func startServer(binary, modelPath string, gpuLayers, ctxSize, parallelSlots int
|
||||||
cmd.Env = serverEnv()
|
cmd.Env = serverEnv()
|
||||||
|
|
||||||
if err := cmd.Start(); err != nil {
|
if err := cmd.Start(); err != nil {
|
||||||
return nil, fmt.Errorf("start llama-server: %w", err)
|
return nil, coreerr.E("rocm.startServer", "start llama-server", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s := &server{
|
s := &server{
|
||||||
|
|
@ -139,15 +140,15 @@ func startServer(binary, modelPath string, gpuLayers, ctxSize, parallelSlots int
|
||||||
select {
|
select {
|
||||||
case <-s.exited:
|
case <-s.exited:
|
||||||
_ = s.stop()
|
_ = s.stop()
|
||||||
lastErr = fmt.Errorf("attempt %d: %w", attempt+1, err)
|
lastErr = coreerr.E("rocm.startServer", fmt.Sprintf("attempt %d", attempt+1), err)
|
||||||
continue
|
continue
|
||||||
default:
|
default:
|
||||||
_ = s.stop()
|
_ = s.stop()
|
||||||
return nil, fmt.Errorf("rocm: llama-server not ready: %w", err)
|
return nil, coreerr.E("rocm.startServer", "llama-server not ready", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, fmt.Errorf("rocm: server failed after %d attempts: %w", maxAttempts, lastErr)
|
return nil, coreerr.E("rocm.startServer", fmt.Sprintf("server failed after %d attempts", maxAttempts), lastErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// waitReady polls the health endpoint until the server is ready.
|
// waitReady polls the health endpoint until the server is ready.
|
||||||
|
|
@ -158,9 +159,9 @@ func (s *server) waitReady(ctx context.Context) error {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return fmt.Errorf("timeout waiting for llama-server: %w", ctx.Err())
|
return coreerr.E("server.waitReady", "timeout waiting for llama-server", ctx.Err())
|
||||||
case <-s.exited:
|
case <-s.exited:
|
||||||
return fmt.Errorf("llama-server exited before becoming ready: %v", s.exitErr)
|
return coreerr.E("server.waitReady", "llama-server exited before becoming ready", s.exitErr)
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
if err := s.client.Health(ctx); err == nil {
|
if err := s.client.Health(ctx); err == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|
@ -184,7 +185,7 @@ func (s *server) stop() error {
|
||||||
|
|
||||||
// Send SIGTERM for graceful shutdown.
|
// Send SIGTERM for graceful shutdown.
|
||||||
if err := s.cmd.Process.Signal(syscall.SIGTERM); err != nil {
|
if err := s.cmd.Process.Signal(syscall.SIGTERM); err != nil {
|
||||||
return fmt.Errorf("sigterm llama-server: %w", err)
|
return coreerr.E("server.stop", "sigterm llama-server", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait up to 5 seconds for clean exit.
|
// Wait up to 5 seconds for clean exit.
|
||||||
|
|
@ -194,7 +195,7 @@ func (s *server) stop() error {
|
||||||
case <-time.After(5 * time.Second):
|
case <-time.After(5 * time.Second):
|
||||||
// Force kill.
|
// Force kill.
|
||||||
if err := s.cmd.Process.Kill(); err != nil {
|
if err := s.cmd.Process.Kill(); err != nil {
|
||||||
return fmt.Errorf("kill llama-server: %w", err)
|
return coreerr.E("server.stop", "kill llama-server", err)
|
||||||
}
|
}
|
||||||
<-s.exited
|
<-s.exited
|
||||||
return s.exitErr
|
return s.exitErr
|
||||||
|
|
|
||||||
11
vram.go
11
vram.go
|
|
@ -3,11 +3,12 @@
|
||||||
package rocm
|
package rocm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
coreerr "forge.lthn.ai/core/go-log"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetVRAMInfo reads VRAM usage for the discrete GPU from sysfs.
|
// GetVRAMInfo reads VRAM usage for the discrete GPU from sysfs.
|
||||||
|
|
@ -19,10 +20,10 @@ import (
|
||||||
func GetVRAMInfo() (VRAMInfo, error) {
|
func GetVRAMInfo() (VRAMInfo, error) {
|
||||||
cards, err := filepath.Glob("/sys/class/drm/card[0-9]*/device/mem_info_vram_total")
|
cards, err := filepath.Glob("/sys/class/drm/card[0-9]*/device/mem_info_vram_total")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return VRAMInfo{}, fmt.Errorf("rocm: glob vram sysfs: %w", err)
|
return VRAMInfo{}, coreerr.E("rocm.GetVRAMInfo", "glob vram sysfs", err)
|
||||||
}
|
}
|
||||||
if len(cards) == 0 {
|
if len(cards) == 0 {
|
||||||
return VRAMInfo{}, fmt.Errorf("rocm: no GPU VRAM info found in sysfs")
|
return VRAMInfo{}, coreerr.E("rocm.GetVRAMInfo", "no GPU VRAM info found in sysfs", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
var bestDir string
|
var bestDir string
|
||||||
|
|
@ -40,12 +41,12 @@ func GetVRAMInfo() (VRAMInfo, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if bestDir == "" {
|
if bestDir == "" {
|
||||||
return VRAMInfo{}, fmt.Errorf("rocm: no readable VRAM sysfs entries")
|
return VRAMInfo{}, coreerr.E("rocm.GetVRAMInfo", "no readable VRAM sysfs entries", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
used, err := readSysfsUint64(filepath.Join(bestDir, "mem_info_vram_used"))
|
used, err := readSysfsUint64(filepath.Join(bestDir, "mem_info_vram_used"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return VRAMInfo{}, fmt.Errorf("rocm: read vram used: %w", err)
|
return VRAMInfo{}, coreerr.E("rocm.GetVRAMInfo", "read vram used", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
free := uint64(0)
|
free := uint64(0)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue