diff --git a/backend.go b/backend.go index dde011a..695e6de 100644 --- a/backend.go +++ b/backend.go @@ -3,8 +3,7 @@ package rocm import ( - "os" - "strings" + "dappco.re/go/core" coreerr "forge.lthn.ai/core/go-log" "forge.lthn.ai/core/go-inference" @@ -22,7 +21,7 @@ func (b *rocmBackend) Name() string { return "rocm" } // b := inference.FindBackend("rocm") // if b.Available() { /* safe to LoadModel */ } func (b *rocmBackend) Available() bool { - if _, err := os.Stat("/dev/kfd"); err != nil { + if !(&core.Fs{}).New("/").Exists("/dev/kfd") { return false } if _, err := findLlamaServer(); err != nil { @@ -42,7 +41,7 @@ func (b *rocmBackend) Available() bool { // ) // defer m.Close() func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (inference.TextModel, error) { - config := inference.ApplyLoadOpts(opts) + configuration := inference.ApplyLoadOpts(opts) binary, err := findLlamaServer() if err != nil { @@ -54,12 +53,12 @@ func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (infe return nil, coreerr.E("rocm.LoadModel", "read model metadata", err) } - ctxLen := config.ContextLen - if ctxLen == 0 && meta.ContextLength > 0 { - ctxLen = int(min(meta.ContextLength, 4096)) + contextLen := configuration.ContextLen + if contextLen == 0 && meta.ContextLength > 0 { + contextLen = int(min(meta.ContextLength, 4096)) } - srv, err := startServer(binary, path, config.GPULayers, ctxLen, config.ParallelSlots) + subprocess, err := startServer(binary, path, configuration.GPULayers, contextLen, configuration.ParallelSlots) if err != nil { return nil, err } @@ -67,34 +66,34 @@ func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (infe // Map quantisation file type to bit width. quantBits := 0 quantGroup := 0 - ftName := gguf.FileTypeName(meta.FileType) + fileTypeName := gguf.FileTypeName(meta.FileType) switch { - case strings.HasPrefix(ftName, "Q4_"): + case core.HasPrefix(fileTypeName, "Q4_"): quantBits = 4 quantGroup = 32 - case strings.HasPrefix(ftName, "Q5_"): + case core.HasPrefix(fileTypeName, "Q5_"): quantBits = 5 quantGroup = 32 - case strings.HasPrefix(ftName, "Q8_"): + case core.HasPrefix(fileTypeName, "Q8_"): quantBits = 8 quantGroup = 32 - case strings.HasPrefix(ftName, "Q2_"): + case core.HasPrefix(fileTypeName, "Q2_"): quantBits = 2 quantGroup = 16 - case strings.HasPrefix(ftName, "Q3_"): + case core.HasPrefix(fileTypeName, "Q3_"): quantBits = 3 quantGroup = 32 - case strings.HasPrefix(ftName, "Q6_"): + case core.HasPrefix(fileTypeName, "Q6_"): quantBits = 6 quantGroup = 64 - case ftName == "F16": + case fileTypeName == "F16": quantBits = 16 - case ftName == "F32": + case fileTypeName == "F32": quantBits = 32 } return &rocmModel{ - server: srv, + server: subprocess, modelType: meta.Architecture, modelInfo: inference.ModelInfo{ Architecture: meta.Architecture, diff --git a/discover.go b/discover.go index 5b23482..44b0872 100644 --- a/discover.go +++ b/discover.go @@ -1,23 +1,18 @@ package rocm import ( - "path/filepath" + "dappco.re/go/core" "forge.lthn.ai/core/go-rocm/internal/gguf" ) -// DiscoverModels scans a directory for GGUF model files. -// Files that cannot be parsed are silently skipped. +// DiscoverModels scans a directory for GGUF model files and returns +// structured information about each. Files that cannot be parsed are skipped. 
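+// Only files matching the *.gguf pattern are considered; other files are ignored.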
// // models, err := rocm.DiscoverModels("/data/lem/gguf") -// for _, m := range models { -// fmt.Printf("%s: %s %s ctx=%d\n", m.Name, m.Architecture, m.Quantisation, m.ContextLen) -// } +// for _, m := range models { core.Print(c, "%s %s ctx=%d", m.Name, m.Quantisation, m.ContextLen) } func DiscoverModels(dir string) ([]ModelInfo, error) { - matches, err := filepath.Glob(filepath.Join(dir, "*.gguf")) - if err != nil { - return nil, err - } + matches := core.PathGlob(core.Path(dir, "*.gguf")) var models []ModelInfo for _, path := range matches { diff --git a/discover_test.go b/discover_test.go index 1f74a3c..2c8a5fe 100644 --- a/discover_test.go +++ b/discover_test.go @@ -3,9 +3,10 @@ package rocm import ( "encoding/binary" "os" - "path/filepath" "testing" + "dappco.re/go/core" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -15,7 +16,7 @@ import ( func writeDiscoverTestGGUF(t *testing.T, dir, filename string, kvs [][2]any) string { t.Helper() - path := filepath.Join(dir, filename) + path := core.Path(dir, filename) f, err := os.Create(path) require.NoError(t, err) @@ -86,7 +87,7 @@ func TestDiscoverModels_Good(t *testing.T) { }) // Create a non-GGUF file that should be ignored (no .gguf extension). - require.NoError(t, os.WriteFile(filepath.Join(dir, "README.txt"), []byte("not a model"), 0644)) + require.NoError(t, os.WriteFile(core.Path(dir, "README.txt"), []byte("not a model"), 0644)) models, err := DiscoverModels(dir) require.NoError(t, err) @@ -94,7 +95,7 @@ func TestDiscoverModels_Good(t *testing.T) { // Sort order from Glob is lexicographic, so gemma3 comes first. gemma := models[0] - assert.Equal(t, filepath.Join(dir, "gemma3-4b-q4km.gguf"), gemma.Path) + assert.Equal(t, core.Path(dir, "gemma3-4b-q4km.gguf"), gemma.Path) assert.Equal(t, "gemma3", gemma.Architecture) assert.Equal(t, "Gemma 3 4B Instruct", gemma.Name) assert.Equal(t, "Q4_K_M", gemma.Quantisation) @@ -103,7 +104,7 @@ func TestDiscoverModels_Good(t *testing.T) { assert.Greater(t, gemma.FileSize, int64(0)) llama := models[1] - assert.Equal(t, filepath.Join(dir, "llama-3.1-8b-q4km.gguf"), llama.Path) + assert.Equal(t, core.Path(dir, "llama-3.1-8b-q4km.gguf"), llama.Path) assert.Equal(t, "llama", llama.Architecture) assert.Equal(t, "Llama 3.1 8B Instruct", llama.Name) assert.Equal(t, "Q4_K_M", llama.Quantisation) @@ -120,8 +121,8 @@ func TestDiscoverModels_Good_EmptyDir(t *testing.T) { assert.Empty(t, models) } -func TestDiscoverModels_Good_NonExistentDir(t *testing.T) { - // filepath.Glob returns nil, nil for a pattern matching no files, +func TestDiscoverModels_Bad_NonExistentDir(t *testing.T) { + // core.PathGlob returns nil for patterns matching no files, // even when the directory does not exist. models, err := DiscoverModels("/nonexistent/dir") require.NoError(t, err) @@ -139,7 +140,7 @@ func TestDiscoverModels_Ugly_SkipsCorruptFile(t *testing.T) { }) // Create a corrupt .gguf file (not valid GGUF binary). 
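	// ReadMetadata rejects it at the magic check; DiscoverModels must
	// skip it rather than fail the whole scan.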
- require.NoError(t, os.WriteFile(filepath.Join(dir, "corrupt.gguf"), []byte("not gguf data"), 0644)) + require.NoError(t, os.WriteFile(core.Path(dir, "corrupt.gguf"), []byte("not gguf data"), 0644)) models, err := DiscoverModels(dir) require.NoError(t, err) diff --git a/go.mod b/go.mod index 027a377..d5ac341 100644 --- a/go.mod +++ b/go.mod @@ -3,15 +3,11 @@ module forge.lthn.ai/core/go-rocm go 1.26.0 require ( + dappco.re/go/core v0.8.0-alpha.1 forge.lthn.ai/core/go-inference v0.1.5 forge.lthn.ai/core/go-log v0.0.4 ) -require ( - dappco.re/go/core v0.8.0-alpha.1 // indirect - github.com/kr/text v0.2.0 // indirect -) - require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect diff --git a/go.sum b/go.sum index 1e66859..0c5f68d 100644 --- a/go.sum +++ b/go.sum @@ -2,7 +2,6 @@ dappco.re/go/core v0.8.0-alpha.1 h1:gj7+Scv+L63Z7wMxbJYHhaRFkHJo2u4MMPuUSv/Dhtk= dappco.re/go/core v0.8.0-alpha.1/go.mod h1:f2/tBZ3+3IqDrg2F5F598llv0nmb/4gJVCFzM5geE4A= forge.lthn.ai/core/go-log v0.0.4 h1:KTuCEPgFmuM8KJfnyQ8vPOU1Jg654W74h8IJvfQMfv0= forge.lthn.ai/core/go-log v0.0.4/go.mod h1:r14MXKOD3LF/sI8XUJQhRk/SZHBE7jAFVuCfgkXoZPw= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= diff --git a/internal/gguf/gguf.go b/internal/gguf/gguf.go index 9018dbe..1c45572 100644 --- a/internal/gguf/gguf.go +++ b/internal/gguf/gguf.go @@ -10,11 +10,11 @@ package gguf import ( "bufio" "encoding/binary" - "fmt" "io" "math" "os" - "strings" + + "dappco.re/go/core" coreerr "forge.lthn.ai/core/go-log" ) @@ -73,68 +73,68 @@ var fileTypeNames = map[uint32]string{ // FileTypeName returns a human-readable name for a GGML quantisation file type. // Unknown types return "type_N" where N is the numeric value. // -// gguf.FileTypeName(15) // "Q4_K_M" -// gguf.FileTypeName(7) // "Q8_0" -// gguf.FileTypeName(1) // "F16" +// name := gguf.FileTypeName(15) // "Q4_K_M" +// name := gguf.FileTypeName(17) // "Q5_K_M" func FileTypeName(ft uint32) string { if name, ok := fileTypeNames[ft]; ok { return name } - return fmt.Sprintf("type_%d", ft) + return core.Sprintf("type_%d", ft) } // ReadMetadata reads the GGUF header from the file at path and returns the // extracted metadata. Only metadata KV pairs are read; tensor data is not loaded. // -// meta, err := gguf.ReadMetadata("/data/model.gguf") -// fmt.Printf("arch=%s ctx=%d quant=%s", meta.Architecture, meta.ContextLength, gguf.FileTypeName(meta.FileType)) +// meta, err := gguf.ReadMetadata("/data/lem/gguf/gemma3-4b-q4km.gguf") +// // meta.Architecture == "gemma3", meta.ContextLength == 32768 func ReadMetadata(path string) (Metadata, error) { - f, err := os.Open(path) - if err != nil { - return Metadata{}, err + openResult := (&core.Fs{}).New("/").Open(path) + if !openResult.OK { + return Metadata{}, openResult.Value.(error) } - defer f.Close() + file := openResult.Value.(*os.File) + defer file.Close() - info, err := f.Stat() + info, err := file.Stat() if err != nil { return Metadata{}, err } - r := bufio.NewReader(f) + reader := bufio.NewReader(file) // Read and validate magic number. 
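	// Header layout (GGUF spec): uint32 magic "GGUF" (0x46554747 little-endian),
	// uint32 version, then tensor and KV counts (uint64 in v3, uint32 in v2).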
var magic uint32 - if err := binary.Read(r, binary.LittleEndian, &magic); err != nil { + if err := binary.Read(reader, binary.LittleEndian, &magic); err != nil { return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading magic", err) } if magic != ggufMagic { - return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("invalid magic: 0x%08X (expected 0x%08X)", magic, ggufMagic), nil) + return Metadata{}, coreerr.E("gguf.ReadMetadata", core.Sprintf("invalid magic: 0x%08X (expected 0x%08X)", magic, ggufMagic), nil) } // Read version. var version uint32 - if err := binary.Read(r, binary.LittleEndian, &version); err != nil { + if err := binary.Read(reader, binary.LittleEndian, &version); err != nil { return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading version", err) } if version < 2 || version > 3 { - return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("unsupported GGUF version: %d", version), nil) + return Metadata{}, coreerr.E("gguf.ReadMetadata", core.Sprintf("unsupported GGUF version: %d", version), nil) } // Read tensor count and KV count. v3 uses uint64, v2 uses uint32. var tensorCount, kvCount uint64 if version == 3 { - if err := binary.Read(r, binary.LittleEndian, &tensorCount); err != nil { + if err := binary.Read(reader, binary.LittleEndian, &tensorCount); err != nil { return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading tensor count", err) } - if err := binary.Read(r, binary.LittleEndian, &kvCount); err != nil { + if err := binary.Read(reader, binary.LittleEndian, &kvCount); err != nil { return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading kv count", err) } } else { var tensorCount32, kvCount32 uint32 - if err := binary.Read(r, binary.LittleEndian, &tensorCount32); err != nil { + if err := binary.Read(reader, binary.LittleEndian, &tensorCount32); err != nil { return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading tensor count", err) } - if err := binary.Read(r, binary.LittleEndian, &kvCount32); err != nil { + if err := binary.Read(reader, binary.LittleEndian, &kvCount32); err != nil { return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading kv count", err) } tensorCount = uint64(tensorCount32) @@ -155,76 +155,76 @@ func ReadMetadata(path string) (Metadata, error) { candidateBlockCount := make(map[string]uint32) for i := uint64(0); i < kvCount; i++ { - key, err := readString(r) + key, err := readString(reader) if err != nil { - return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading key %d", i), err) + return Metadata{}, coreerr.E("gguf.ReadMetadata", core.Sprintf("reading key %d", i), err) } var valType uint32 - if err := binary.Read(r, binary.LittleEndian, &valType); err != nil { - return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value type for key %q", key), err) + if err := binary.Read(reader, binary.LittleEndian, &valType); err != nil { + return Metadata{}, coreerr.E("gguf.ReadMetadata", core.Sprintf("reading value type for key %q", key), err) } // Check whether this is an interesting key before reading the value. 
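		// Only the keys Metadata actually uses are decoded; everything else
		// is skipped without being materialised, so large KV entries such as
		// tokenizer arrays are never held in memory.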
switch { case key == "general.architecture": - v, err := readTypedValue(r, valType) + rawValue, err := readTypedValue(reader, valType) if err != nil { - return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err) + return Metadata{}, coreerr.E("gguf.ReadMetadata", core.Sprintf("reading value for key %q", key), err) } - if s, ok := v.(string); ok { - meta.Architecture = s + if stringValue, ok := rawValue.(string); ok { + meta.Architecture = stringValue } case key == "general.name": - v, err := readTypedValue(r, valType) + rawValue, err := readTypedValue(reader, valType) if err != nil { - return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err) + return Metadata{}, coreerr.E("gguf.ReadMetadata", core.Sprintf("reading value for key %q", key), err) } - if s, ok := v.(string); ok { - meta.Name = s + if stringValue, ok := rawValue.(string); ok { + meta.Name = stringValue } case key == "general.file_type": - v, err := readTypedValue(r, valType) + rawValue, err := readTypedValue(reader, valType) if err != nil { - return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err) + return Metadata{}, coreerr.E("gguf.ReadMetadata", core.Sprintf("reading value for key %q", key), err) } - if u, ok := v.(uint32); ok { - meta.FileType = u + if uint32Value, ok := rawValue.(uint32); ok { + meta.FileType = uint32Value } case key == "general.size_label": - v, err := readTypedValue(r, valType) + rawValue, err := readTypedValue(reader, valType) if err != nil { - return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err) + return Metadata{}, coreerr.E("gguf.ReadMetadata", core.Sprintf("reading value for key %q", key), err) } - if s, ok := v.(string); ok { - meta.SizeLabel = s + if stringValue, ok := rawValue.(string); ok { + meta.SizeLabel = stringValue } - case strings.HasSuffix(key, ".context_length"): - v, err := readTypedValue(r, valType) + case core.HasSuffix(key, ".context_length"): + rawValue, err := readTypedValue(reader, valType) if err != nil { - return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err) + return Metadata{}, coreerr.E("gguf.ReadMetadata", core.Sprintf("reading value for key %q", key), err) } - if u, ok := v.(uint32); ok { - candidateContextLength[key] = u + if uint32Value, ok := rawValue.(uint32); ok { + candidateContextLength[key] = uint32Value } - case strings.HasSuffix(key, ".block_count"): - v, err := readTypedValue(r, valType) + case core.HasSuffix(key, ".block_count"): + rawValue, err := readTypedValue(reader, valType) if err != nil { - return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err) + return Metadata{}, coreerr.E("gguf.ReadMetadata", core.Sprintf("reading value for key %q", key), err) } - if u, ok := v.(uint32); ok { - candidateBlockCount[key] = u + if uint32Value, ok := rawValue.(uint32); ok { + candidateBlockCount[key] = uint32Value } default: // Skip uninteresting value. - if err := skipValue(r, valType); err != nil { - return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("skipping value for key %q", key), err) + if err := skipValue(reader, valType); err != nil { + return Metadata{}, coreerr.E("gguf.ReadMetadata", core.Sprintf("skipping value for key %q", key), err) } } } @@ -232,11 +232,11 @@ func ReadMetadata(path string) (Metadata, error) { // Resolve architecture-specific keys. 
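	// The architecture key may appear after the *.context_length and
	// *.block_count keys in the stream, so suffix matches are collected
	// first and the "<arch>."-prefixed entries are picked here, e.g.
	// "gemma3.context_length" when meta.Architecture is "gemma3".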
if meta.Architecture != "" { prefix := meta.Architecture + "." - if v, ok := candidateContextLength[prefix+"context_length"]; ok { - meta.ContextLength = v + if contextLength, ok := candidateContextLength[prefix+"context_length"]; ok { + meta.ContextLength = contextLength } - if v, ok := candidateBlockCount[prefix+"block_count"]; ok { - meta.BlockCount = v + if blockCount, ok := candidateBlockCount[prefix+"block_count"]; ok { + meta.BlockCount = blockCount } } @@ -248,94 +248,94 @@ func ReadMetadata(path string) (Metadata, error) { const maxStringLength = 1 << 20 // readString reads a GGUF string: uint64 length followed by that many bytes. -func readString(r io.Reader) (string, error) { +func readString(reader io.Reader) (string, error) { var length uint64 - if err := binary.Read(r, binary.LittleEndian, &length); err != nil { + if err := binary.Read(reader, binary.LittleEndian, &length); err != nil { return "", err } if length > maxStringLength { - return "", coreerr.E("gguf.readString", fmt.Sprintf("string length %d exceeds maximum %d", length, maxStringLength), nil) + return "", coreerr.E("gguf.readString", core.Sprintf("string length %d exceeds maximum %d", length, maxStringLength), nil) } - buf := make([]byte, length) - if _, err := io.ReadFull(r, buf); err != nil { + buffer := make([]byte, length) + if _, err := io.ReadFull(reader, buffer); err != nil { return "", err } - return string(buf), nil + return string(buffer), nil } // readTypedValue reads a value of the given GGUF type and returns it as a Go // value. String, uint32, and uint64 types return typed values (uint64 is // downcast to uint32 when it fits). All others are read and discarded. -func readTypedValue(r io.Reader, valType uint32) (any, error) { +func readTypedValue(reader io.Reader, valType uint32) (any, error) { switch valType { case typeString: - return readString(r) + return readString(reader) case typeUint32: - var v uint32 - err := binary.Read(r, binary.LittleEndian, &v) - return v, err + var uint32Value uint32 + err := binary.Read(reader, binary.LittleEndian, &uint32Value) + return uint32Value, err case typeUint64: - var v uint64 - if err := binary.Read(r, binary.LittleEndian, &v); err != nil { + var uint64Value uint64 + if err := binary.Read(reader, binary.LittleEndian, &uint64Value); err != nil { return nil, err } - if v <= math.MaxUint32 { - return uint32(v), nil + if uint64Value <= math.MaxUint32 { + return uint32(uint64Value), nil } - return v, nil + return uint64Value, nil default: // Read and discard the value, returning nil. - err := skipValue(r, valType) + err := skipValue(reader, valType) return nil, err } } -// skipValue reads and discards a GGUF value of the given type from r. -func skipValue(r io.Reader, valType uint32) error { +// skipValue reads and discards a GGUF value of the given type from reader. 
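+//
+// Sizes follow the GGUF value-type encoding:
+//
+//	1 byte:  uint8, int8, bool
+//	2 bytes: uint16, int16
+//	4 bytes: uint32, int32, float32
+//	8 bytes: uint64, int64, float64
+//	string:  uint64 length, then that many bytes
+//	array:   uint32 element type, uint64 count, then count values (recursive)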
+func skipValue(reader io.Reader, valType uint32) error { switch valType { case typeUint8, typeInt8, typeBool: - _, err := readN(r, 1) + _, err := readN(reader, 1) return err case typeUint16, typeInt16: - _, err := readN(r, 2) + _, err := readN(reader, 2) return err case typeUint32, typeInt32, typeFloat32: - _, err := readN(r, 4) + _, err := readN(reader, 4) return err case typeUint64, typeInt64, typeFloat64: - _, err := readN(r, 8) + _, err := readN(reader, 8) return err case typeString: var length uint64 - if err := binary.Read(r, binary.LittleEndian, &length); err != nil { + if err := binary.Read(reader, binary.LittleEndian, &length); err != nil { return err } if length > maxStringLength { - return coreerr.E("gguf.skipValue", fmt.Sprintf("string length %d exceeds maximum %d", length, maxStringLength), nil) + return coreerr.E("gguf.skipValue", core.Sprintf("string length %d exceeds maximum %d", length, maxStringLength), nil) } - _, err := readN(r, int64(length)) + _, err := readN(reader, int64(length)) return err case typeArray: var elemType uint32 - if err := binary.Read(r, binary.LittleEndian, &elemType); err != nil { + if err := binary.Read(reader, binary.LittleEndian, &elemType); err != nil { return err } var count uint64 - if err := binary.Read(r, binary.LittleEndian, &count); err != nil { + if err := binary.Read(reader, binary.LittleEndian, &count); err != nil { return err } for i := uint64(0); i < count; i++ { - if err := skipValue(r, elemType); err != nil { + if err := skipValue(reader, elemType); err != nil { return err } } return nil default: - return coreerr.E("gguf.skipValue", fmt.Sprintf("unknown GGUF value type: %d", valType), nil) + return coreerr.E("gguf.skipValue", core.Sprintf("unknown GGUF value type: %d", valType), nil) } } -// readN reads and discards exactly n bytes from r. -func readN(r io.Reader, n int64) (int64, error) { - return io.CopyN(io.Discard, r, n) +// readN reads and discards exactly n bytes from reader. 
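+//
+// Used by skipValue to drop fixed-size scalars and string payloads, e.g.:
+//
+//	_, err := readN(reader, 8) // discard a uint64/float64 value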
+func readN(reader io.Reader, n int64) (int64, error) { + return io.CopyN(io.Discard, reader, n) } diff --git a/internal/gguf/gguf_test.go b/internal/gguf/gguf_test.go index ecc136b..fb78637 100644 --- a/internal/gguf/gguf_test.go +++ b/internal/gguf/gguf_test.go @@ -3,9 +3,10 @@ package gguf import ( "encoding/binary" "os" - "path/filepath" "testing" + "dappco.re/go/core" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -16,7 +17,7 @@ func writeTestGGUFOrdered(t *testing.T, kvs [][2]any) string { t.Helper() dir := t.TempDir() - path := filepath.Join(dir, "test.gguf") + path := core.Path(dir, "test.gguf") f, err := os.Create(path) require.NoError(t, err) @@ -86,7 +87,7 @@ func writeTestGGUFV2(t *testing.T, kvs [][2]any) string { t.Helper() dir := t.TempDir() - path := filepath.Join(dir, "test_v2.gguf") + path := core.Path(dir, "test_v2.gguf") f, err := os.Create(path) require.NoError(t, err) @@ -176,7 +177,7 @@ func TestReadMetadata_Ugly_ArchAfterContextLength(t *testing.T) { func TestReadMetadata_Bad_InvalidMagic(t *testing.T) { dir := t.TempDir() - path := filepath.Join(dir, "notgguf.bin") + path := core.Path(dir, "notgguf.bin") err := os.WriteFile(path, []byte("this is not a GGUF file at all"), 0644) require.NoError(t, err) @@ -196,7 +197,17 @@ func TestFileTypeName_Good(t *testing.T) { assert.Equal(t, "Q5_K_M", FileTypeName(17)) assert.Equal(t, "Q8_0", FileTypeName(7)) assert.Equal(t, "F16", FileTypeName(1)) - assert.Equal(t, "type_999", FileTypeName(999)) +} + +func TestFileTypeName_Bad_UnknownType(t *testing.T) { + // Unknown type numbers must not panic; return "type_N". + name := FileTypeName(999) + assert.Equal(t, "type_999", name) +} + +func TestFileTypeName_Ugly_ZeroType(t *testing.T) { + // Type 0 is F32 — the zero value must map to a valid name, not "type_0". + assert.Equal(t, "F32", FileTypeName(0)) } func TestReadMetadata_Good_V2(t *testing.T) { @@ -221,7 +232,7 @@ func TestReadMetadata_Good_V2(t *testing.T) { func TestReadMetadata_Bad_UnsupportedVersion(t *testing.T) { dir := t.TempDir() - path := filepath.Join(dir, "bad_version.gguf") + path := core.Path(dir, "bad_version.gguf") f, err := os.Create(path) require.NoError(t, err) @@ -239,7 +250,7 @@ func TestReadMetadata_Ugly_SkipsUnknownValueTypes(t *testing.T) { // Tests skipValue for uint8, int16, float32, uint64, bool, and array types. // These are stored under uninteresting keys so ReadMetadata skips them. dir := t.TempDir() - path := filepath.Join(dir, "skip_types.gguf") + path := core.Path(dir, "skip_types.gguf") f, err := os.Create(path) require.NoError(t, err) @@ -317,7 +328,7 @@ func TestReadMetadata_Ugly_Uint64ContextLength(t *testing.T) { func TestReadMetadata_Bad_TruncatedFile(t *testing.T) { dir := t.TempDir() - path := filepath.Join(dir, "truncated.gguf") + path := core.Path(dir, "truncated.gguf") // Write only the magic — no version or counts. f, err := os.Create(path) @@ -333,7 +344,7 @@ func TestReadMetadata_Bad_TruncatedFile(t *testing.T) { func TestReadMetadata_Ugly_SkipsStringValue(t *testing.T) { // Tests skipValue for string type (type 8) on an uninteresting key. 
dir := t.TempDir() - path := filepath.Join(dir, "skip_string.gguf") + path := core.Path(dir, "skip_string.gguf") f, err := os.Create(path) require.NoError(t, err) diff --git a/internal/llamacpp/client.go b/internal/llamacpp/client.go index d91970c..82eeb84 100644 --- a/internal/llamacpp/client.go +++ b/internal/llamacpp/client.go @@ -4,14 +4,13 @@ import ( "bufio" "bytes" "context" - "encoding/json" - "fmt" "io" "iter" "net/http" - "strings" "sync" + "dappco.re/go/core" + coreerr "forge.lthn.ai/core/go-log" ) @@ -60,55 +59,56 @@ type completionChunkResponse struct { } // ChatComplete sends a streaming chat completion request to /v1/chat/completions. -// Returns an iterator over text chunks and an error accessor called after ranging. +// It returns an iterator over text chunks and a function that returns any error +// that occurred during the request or while reading the stream. // -// chunks, errFn := client.ChatComplete(ctx, llamacpp.ChatRequest{ -// Messages: []llamacpp.ChatMessage{{Role: "user", Content: "Hello"}}, -// MaxTokens: 128, Temperature: 0.7, -// }) -// for text := range chunks { fmt.Print(text) } -// if err := errFn(); err != nil { /* handle */ } -func (c *Client) ChatComplete(ctx context.Context, req ChatRequest) (iter.Seq[string], func() error) { - req.Stream = true +// chunks, errorFunc := client.ChatComplete(ctx, ChatRequest{Messages: msgs, Temperature: 0.7}) +// for text := range chunks { output += text } +// if err := errorFunc(); err != nil { /* handle */ } +func (c *Client) ChatComplete(ctx context.Context, chatRequest ChatRequest) (iter.Seq[string], func() error) { + chatRequest.Stream = true - body, err := json.Marshal(req) - if err != nil { - return noChunks, func() error { return coreerr.E("llamacpp.ChatComplete", "marshal chat request", err) } + marshalResult := core.JSONMarshal(chatRequest) + if !marshalResult.OK { + marshalErr := marshalResult.Value.(error) + return noChunks, func() error { return coreerr.E("llamacpp.ChatComplete", "marshal chat request", marshalErr) } } + body := marshalResult.Value.([]byte) - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/chat/completions", bytes.NewReader(body)) + httpRequest, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/chat/completions", bytes.NewReader(body)) if err != nil { return noChunks, func() error { return coreerr.E("llamacpp.ChatComplete", "create chat request", err) } } - httpReq.Header.Set("Content-Type", "application/json") - httpReq.Header.Set("Accept", "text/event-stream") + httpRequest.Header.Set("Content-Type", "application/json") + httpRequest.Header.Set("Accept", "text/event-stream") - resp, err := c.httpClient.Do(httpReq) + response, err := c.httpClient.Do(httpRequest) if err != nil { return noChunks, func() error { return coreerr.E("llamacpp.ChatComplete", "chat request", err) } } - if resp.StatusCode != http.StatusOK { - defer resp.Body.Close() - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 256)) + if response.StatusCode != http.StatusOK { + defer response.Body.Close() + responseBody, _ := io.ReadAll(io.LimitReader(response.Body, 256)) return noChunks, func() error { - return coreerr.E("llamacpp.ChatComplete", fmt.Sprintf("chat returned %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody))), nil) + return coreerr.E("llamacpp.ChatComplete", core.Sprintf("chat returned %d: %s", response.StatusCode, core.Trim(string(responseBody))), nil) } } var ( streamErr error closeOnce sync.Once - closeBody = func() { closeOnce.Do(func() { 
resp.Body.Close() }) } + closeBody = func() { closeOnce.Do(func() { response.Body.Close() }) } ) - sseData := parseSSE(resp.Body, &streamErr) + sseData := parseSSE(response.Body, &streamErr) tokens := func(yield func(string) bool) { defer closeBody() - for raw := range sseData { + for ssePayload := range sseData { var chunk chatChunkResponse - if err := json.Unmarshal([]byte(raw), &chunk); err != nil { - streamErr = coreerr.E("llamacpp.ChatComplete", "decode chat chunk", err) + decodeResult := core.JSONUnmarshal([]byte(ssePayload), &chunk) + if !decodeResult.OK { + streamErr = coreerr.E("llamacpp.ChatComplete", "decode chat chunk", decodeResult.Value.(error)) return } if len(chunk.Choices) == 0 { @@ -131,54 +131,56 @@ func (c *Client) ChatComplete(ctx context.Context, req ChatRequest) (iter.Seq[st } // Complete sends a streaming completion request to /v1/completions. -// Returns an iterator over text chunks and an error accessor called after ranging. +// It returns an iterator over text chunks and a function that returns any error +// that occurred during the request or while reading the stream. // -// chunks, errFn := client.Complete(ctx, llamacpp.CompletionRequest{ -// Prompt: "The capital of France is", MaxTokens: 16, Temperature: 0.0, -// }) -// for text := range chunks { fmt.Print(text) } -// if err := errFn(); err != nil { /* handle */ } -func (c *Client) Complete(ctx context.Context, req CompletionRequest) (iter.Seq[string], func() error) { - req.Stream = true +// chunks, errorFunc := client.Complete(ctx, CompletionRequest{Prompt: "Once upon", Temperature: 0.8}) +// for text := range chunks { output += text } +// if err := errorFunc(); err != nil { /* handle */ } +func (c *Client) Complete(ctx context.Context, completionRequest CompletionRequest) (iter.Seq[string], func() error) { + completionRequest.Stream = true - body, err := json.Marshal(req) - if err != nil { - return noChunks, func() error { return coreerr.E("llamacpp.Complete", "marshal completion request", err) } + marshalResult := core.JSONMarshal(completionRequest) + if !marshalResult.OK { + marshalErr := marshalResult.Value.(error) + return noChunks, func() error { return coreerr.E("llamacpp.Complete", "marshal completion request", marshalErr) } } + body := marshalResult.Value.([]byte) - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/completions", bytes.NewReader(body)) + httpRequest, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/completions", bytes.NewReader(body)) if err != nil { return noChunks, func() error { return coreerr.E("llamacpp.Complete", "create completion request", err) } } - httpReq.Header.Set("Content-Type", "application/json") - httpReq.Header.Set("Accept", "text/event-stream") + httpRequest.Header.Set("Content-Type", "application/json") + httpRequest.Header.Set("Accept", "text/event-stream") - resp, err := c.httpClient.Do(httpReq) + response, err := c.httpClient.Do(httpRequest) if err != nil { return noChunks, func() error { return coreerr.E("llamacpp.Complete", "completion request", err) } } - if resp.StatusCode != http.StatusOK { - defer resp.Body.Close() - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 256)) + if response.StatusCode != http.StatusOK { + defer response.Body.Close() + responseBody, _ := io.ReadAll(io.LimitReader(response.Body, 256)) return noChunks, func() error { - return coreerr.E("llamacpp.Complete", fmt.Sprintf("completion returned %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody))), nil) + return 
coreerr.E("llamacpp.Complete", core.Sprintf("completion returned %d: %s", response.StatusCode, core.Trim(string(responseBody))), nil) } } var ( streamErr error closeOnce sync.Once - closeBody = func() { closeOnce.Do(func() { resp.Body.Close() }) } + closeBody = func() { closeOnce.Do(func() { response.Body.Close() }) } ) - sseData := parseSSE(resp.Body, &streamErr) + sseData := parseSSE(response.Body, &streamErr) tokens := func(yield func(string) bool) { defer closeBody() - for raw := range sseData { + for ssePayload := range sseData { var chunk completionChunkResponse - if err := json.Unmarshal([]byte(raw), &chunk); err != nil { - streamErr = coreerr.E("llamacpp.Complete", "decode completion chunk", err) + decodeResult := core.JSONUnmarshal([]byte(ssePayload), &chunk) + if !decodeResult.OK { + streamErr = coreerr.E("llamacpp.Complete", "decode completion chunk", decodeResult.Value.(error)) return } if len(chunk.Choices) == 0 { @@ -200,18 +202,18 @@ func (c *Client) Complete(ctx context.Context, req CompletionRequest) (iter.Seq[ } } -// parseSSE reads SSE-formatted lines from r and yields the payload of each +// parseSSE reads SSE-formatted lines from reader and yields the payload of each // "data: " line. It stops when it encounters "[DONE]" or an I/O error. // Any read error (other than EOF) is stored via errOut. -func parseSSE(r io.Reader, errOut *error) iter.Seq[string] { +func parseSSE(reader io.Reader, errOut *error) iter.Seq[string] { return func(yield func(string) bool) { - scanner := bufio.NewScanner(r) + scanner := bufio.NewScanner(reader) for scanner.Scan() { line := scanner.Text() - if !strings.HasPrefix(line, "data: ") { + if !core.HasPrefix(line, "data: ") { continue } - payload := strings.TrimPrefix(line, "data: ") + payload := core.TrimPrefix(line, "data: ") if payload == "[DONE]" { return } diff --git a/internal/llamacpp/client_test.go b/internal/llamacpp/client_test.go index 5e85b01..0b11840 100644 --- a/internal/llamacpp/client_test.go +++ b/internal/llamacpp/client_test.go @@ -2,7 +2,7 @@ package llamacpp import ( "context" - "fmt" + "io" "net/http" "net/http/httptest" "testing" @@ -22,7 +22,7 @@ func sseLines(w http.ResponseWriter, lines []string) { w.WriteHeader(http.StatusOK) for _, line := range lines { - fmt.Fprintf(w, "data: %s\n\n", line) + _, _ = io.WriteString(w, "data: "+line+"\n\n") f.Flush() } } @@ -114,7 +114,7 @@ func TestChatComplete_Ugly_ContextCancelled(t *testing.T) { w.WriteHeader(http.StatusOK) // Send first chunk. - fmt.Fprintf(w, "data: %s\n\n", `{"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}`) + _, _ = io.WriteString(w, "data: "+`{"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}`+"\n\n") f.Flush() // Wait for context cancellation before sending more. @@ -192,3 +192,43 @@ func TestComplete_Bad_HTTPError(t *testing.T) { require.Error(t, err) assert.Contains(t, err.Error(), "400") } + +func TestComplete_Ugly_ContextCancelled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + f, ok := w.(http.Flusher) + if !ok { + panic("ResponseWriter does not implement Flusher") + } + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.WriteHeader(http.StatusOK) + + // Send first chunk. + _, _ = io.WriteString(w, "data: "+`{"choices":[{"text":"Once","finish_reason":null}]}`+"\n\n") + f.Flush() + + // Wait for context cancellation before sending more. 
+ <-r.Context().Done() + })) + defer ts.Close() + + c := NewClient(ts.URL) + tokens, errFn := c.Complete(ctx, CompletionRequest{ + Prompt: "Once", + Temperature: 0.7, + Stream: true, + }) + + var got []string + for tok := range tokens { + got = append(got, tok) + cancel() // Cancel after receiving the first token. + } + // The error may or may not be nil depending on timing; + // the important thing is we got exactly 1 token. + _ = errFn() + assert.Equal(t, []string{"Once"}, got) +} diff --git a/internal/llamacpp/health.go b/internal/llamacpp/health.go index ecd422c..6979c4b 100644 --- a/internal/llamacpp/health.go +++ b/internal/llamacpp/health.go @@ -2,11 +2,10 @@ package llamacpp import ( "context" - "encoding/json" - "fmt" "io" "net/http" - "strings" + + "dappco.re/go/core" coreerr "forge.lthn.ai/core/go-log" ) @@ -20,12 +19,9 @@ type Client struct { // NewClient creates a client for the llama-server at the given base URL. // // client := llamacpp.NewClient("http://127.0.0.1:8080") -// if err := client.Health(ctx); err != nil { -// // server not ready -// } func NewClient(baseURL string) *Client { return &Client{ - baseURL: strings.TrimRight(baseURL, "/"), + baseURL: core.TrimSuffix(baseURL, "/"), httpClient: &http.Client{}, } } @@ -35,32 +31,36 @@ type healthResponse struct { } // Health checks whether the llama-server is ready to accept requests. -// Returns nil when the server responds with {"status":"ok"}. // -// if err := client.Health(ctx); err != nil { -// log.Printf("server not ready: %v", err) -// } +// if err := client.Health(ctx); err != nil { /* server not ready */ } func (c *Client) Health(ctx context.Context) error { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+"/health", nil) + request, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+"/health", nil) if err != nil { return err } - resp, err := c.httpClient.Do(req) + response, err := c.httpClient.Do(request) if err != nil { return err } - defer resp.Body.Close() + defer response.Body.Close() - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(io.LimitReader(resp.Body, 256)) - return coreerr.E("llamacpp.Health", fmt.Sprintf("health returned %d: %s", resp.StatusCode, string(body)), nil) + if response.StatusCode != http.StatusOK { + body, _ := io.ReadAll(io.LimitReader(response.Body, 256)) + return coreerr.E("llamacpp.Health", core.Sprintf("health returned %d: %s", response.StatusCode, string(body)), nil) } - var h healthResponse - if err := json.NewDecoder(resp.Body).Decode(&h); err != nil { - return coreerr.E("llamacpp.Health", "health decode", err) + var healthStatus healthResponse + decodeResult := core.JSONUnmarshal(mustReadAll(response.Body), &healthStatus) + if !decodeResult.OK { + return coreerr.E("llamacpp.Health", "health decode", decodeResult.Value.(error)) } - if h.Status != "ok" { - return coreerr.E("llamacpp.Health", fmt.Sprintf("server not ready (status: %s)", h.Status), nil) + if healthStatus.Status != "ok" { + return coreerr.E("llamacpp.Health", core.Sprintf("server not ready (status: %s)", healthStatus.Status), nil) } return nil } + +// mustReadAll reads all bytes from reader, returning nil on error. 
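+//
+// Unlike the usual must* convention it does not panic: a read failure yields
+// nil, which then surfaces as a "health decode" error in Health (assuming
+// core.JSONUnmarshal, like encoding/json, rejects empty input).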
+func mustReadAll(reader io.Reader) []byte { + data, _ := io.ReadAll(reader) + return data +} diff --git a/internal/llamacpp/health_test.go b/internal/llamacpp/health_test.go index bb42e56..11d0619 100644 --- a/internal/llamacpp/health_test.go +++ b/internal/llamacpp/health_test.go @@ -53,3 +53,15 @@ func TestHealth_Bad_ServerDown(t *testing.T) { err := c.Health(context.Background()) assert.Error(t, err) } + +func TestHealth_Ugly_MalformedJSON(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`not valid json`)) + })) + defer ts.Close() + + c := NewClient(ts.URL) + err := c.Health(context.Background()) + assert.Error(t, err) +} diff --git a/model.go b/model.go index 98643d1..315303e 100644 --- a/model.go +++ b/model.go @@ -4,12 +4,12 @@ package rocm import ( "context" - "fmt" "iter" - "strings" "sync" "time" + "dappco.re/go/core" + coreerr "forge.lthn.ai/core/go-log" "forge.lthn.ai/core/go-inference" "forge.lthn.ai/core/go-rocm/internal/llamacpp" @@ -27,6 +27,10 @@ type rocmModel struct { } // Generate streams tokens for the given prompt via llama-server's /v1/completions endpoint. +// +// for tok := range m.Generate(ctx, "The capital of France is", inference.WithMaxTokens(32)) { +// output += tok.Text +// } func (m *rocmModel) Generate(ctx context.Context, prompt string, opts ...inference.GenerateOption) iter.Seq[inference.Token] { m.mu.Lock() m.lastErr = nil @@ -37,39 +41,44 @@ func (m *rocmModel) Generate(ctx context.Context, prompt string, opts ...inferen return func(yield func(inference.Token) bool) {} } - config := inference.ApplyGenerateOpts(opts) + configuration := inference.ApplyGenerateOpts(opts) - req := llamacpp.CompletionRequest{ + completionRequest := llamacpp.CompletionRequest{ Prompt: prompt, - MaxTokens: config.MaxTokens, - Temperature: config.Temperature, - TopK: config.TopK, - TopP: config.TopP, - RepeatPenalty: config.RepeatPenalty, + MaxTokens: configuration.MaxTokens, + Temperature: configuration.Temperature, + TopK: configuration.TopK, + TopP: configuration.TopP, + RepeatPenalty: configuration.RepeatPenalty, } start := time.Now() - chunks, errFn := m.server.client.Complete(ctx, req) + chunks, errorFunc := m.server.client.Complete(ctx, completionRequest) return func(yield func(inference.Token) bool) { - var count int + var tokenCount int decodeStart := time.Now() for text := range chunks { - count++ + tokenCount++ if !yield(inference.Token{Text: text}) { break } } - if err := errFn(); err != nil { + if err := errorFunc(); err != nil { m.mu.Lock() m.lastErr = err m.mu.Unlock() } - m.recordMetrics(0, count, start, decodeStart) + m.recordMetrics(0, tokenCount, start, decodeStart) } } // Chat streams tokens from a multi-turn conversation via llama-server's /v1/chat/completions endpoint. 
+// +// msgs := []inference.Message{{Role: "user", Content: "Hello"}} +// for tok := range m.Chat(ctx, msgs, inference.WithMaxTokens(64)) { +// output += tok.Text +// } func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts ...inference.GenerateOption) iter.Seq[inference.Token] { m.mu.Lock() m.lastErr = nil @@ -80,49 +89,52 @@ func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts return func(yield func(inference.Token) bool) {} } - config := inference.ApplyGenerateOpts(opts) + configuration := inference.ApplyGenerateOpts(opts) - chatMsgs := make([]llamacpp.ChatMessage, len(messages)) + chatMessages := make([]llamacpp.ChatMessage, len(messages)) for i, msg := range messages { - chatMsgs[i] = llamacpp.ChatMessage{ + chatMessages[i] = llamacpp.ChatMessage{ Role: msg.Role, Content: msg.Content, } } - req := llamacpp.ChatRequest{ - Messages: chatMsgs, - MaxTokens: config.MaxTokens, - Temperature: config.Temperature, - TopK: config.TopK, - TopP: config.TopP, - RepeatPenalty: config.RepeatPenalty, + chatRequest := llamacpp.ChatRequest{ + Messages: chatMessages, + MaxTokens: configuration.MaxTokens, + Temperature: configuration.Temperature, + TopK: configuration.TopK, + TopP: configuration.TopP, + RepeatPenalty: configuration.RepeatPenalty, } start := time.Now() - chunks, errFn := m.server.client.ChatComplete(ctx, req) + chunks, errorFunc := m.server.client.ChatComplete(ctx, chatRequest) return func(yield func(inference.Token) bool) { - var count int + var tokenCount int decodeStart := time.Now() for text := range chunks { - count++ + tokenCount++ if !yield(inference.Token{Text: text}) { break } } - if err := errFn(); err != nil { + if err := errorFunc(); err != nil { m.mu.Lock() m.lastErr = err m.mu.Unlock() } - m.recordMetrics(0, count, start, decodeStart) + m.recordMetrics(0, tokenCount, start, decodeStart) } } // Classify runs batched prefill-only inference via llama-server. // Each prompt gets a single-token completion (max_tokens=1, temperature=0). // llama-server has no native classify endpoint, so this simulates it. +// +// results, err := m.Classify(ctx, []string{"positive review", "negative review"}) +// // results[0].Token.Text == "pos" or similar top token func (m *rocmModel) Classify(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.ClassifyResult, error) { if !m.server.alive() { m.setServerExitErr() @@ -137,23 +149,23 @@ func (m *rocmModel) Classify(ctx context.Context, prompts []string, opts ...infe return nil, ctx.Err() } - req := llamacpp.CompletionRequest{ + completionRequest := llamacpp.CompletionRequest{ Prompt: prompt, MaxTokens: 1, Temperature: 0, } - chunks, errFn := m.server.client.Complete(ctx, req) - var text strings.Builder + chunks, errorFunc := m.server.client.Complete(ctx, completionRequest) + builder := core.NewBuilder() for chunk := range chunks { - text.WriteString(chunk) + builder.WriteString(chunk) } - if err := errFn(); err != nil { - return nil, coreerr.E("rocm.Classify", fmt.Sprintf("classify prompt %d", i), err) + if err := errorFunc(); err != nil { + return nil, coreerr.E("rocm.Classify", core.Sprintf("classify prompt %d", i), err) } results[i] = inference.ClassifyResult{ - Token: inference.Token{Text: text.String()}, + Token: inference.Token{Text: builder.String()}, } } @@ -163,13 +175,16 @@ func (m *rocmModel) Classify(ctx context.Context, prompts []string, opts ...infe // BatchGenerate runs batched autoregressive generation via llama-server. 
// Each prompt is decoded sequentially up to MaxTokens. +// +// results, err := m.BatchGenerate(ctx, []string{"prompt A", "prompt B"}, inference.WithMaxTokens(64)) +// // results[0].Tokens — tokens generated for prompt A func (m *rocmModel) BatchGenerate(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.BatchResult, error) { if !m.server.alive() { m.setServerExitErr() return nil, m.Err() } - config := inference.ApplyGenerateOpts(opts) + configuration := inference.ApplyGenerateOpts(opts) start := time.Now() results := make([]inference.BatchResult, len(prompts)) var totalGenerated int @@ -180,22 +195,22 @@ func (m *rocmModel) BatchGenerate(ctx context.Context, prompts []string, opts .. continue } - req := llamacpp.CompletionRequest{ + completionRequest := llamacpp.CompletionRequest{ Prompt: prompt, - MaxTokens: config.MaxTokens, - Temperature: config.Temperature, - TopK: config.TopK, - TopP: config.TopP, - RepeatPenalty: config.RepeatPenalty, + MaxTokens: configuration.MaxTokens, + Temperature: configuration.Temperature, + TopK: configuration.TopK, + TopP: configuration.TopP, + RepeatPenalty: configuration.RepeatPenalty, } - chunks, errFn := m.server.client.Complete(ctx, req) + chunks, errorFunc := m.server.client.Complete(ctx, completionRequest) var tokens []inference.Token for text := range chunks { tokens = append(tokens, inference.Token{Text: text}) } - if err := errFn(); err != nil { - results[i].Err = coreerr.E("rocm.BatchGenerate", fmt.Sprintf("batch prompt %d", i), err) + if err := errorFunc(); err != nil { + results[i].Err = coreerr.E("rocm.BatchGenerate", core.Sprintf("batch prompt %d", i), err) } results[i].Tokens = tokens totalGenerated += len(tokens) @@ -206,12 +221,20 @@ func (m *rocmModel) BatchGenerate(ctx context.Context, prompts []string, opts .. } // ModelType returns the architecture identifier (e.g. "gemma3", "qwen3", "llama3"). +// +// arch := m.ModelType() // "gemma3" func (m *rocmModel) ModelType() string { return m.modelType } // Info returns metadata about the loaded model. +// +// info := m.Info() +// // info.Architecture == "gemma3", info.NumLayers == 26 func (m *rocmModel) Info() inference.ModelInfo { return m.modelInfo } // Metrics returns performance metrics from the last inference operation. +// +// metrics := m.Metrics() +// // metrics.DecodeTokensPerSec, metrics.TotalDuration func (m *rocmModel) Metrics() inference.GenerateMetrics { m.mu.Lock() defer m.mu.Unlock() @@ -219,6 +242,9 @@ func (m *rocmModel) Metrics() inference.GenerateMetrics { } // Err returns the error from the last Generate/Chat call, if any. +// +// for tok := range m.Generate(ctx, prompt) { } +// if err := m.Err(); err != nil { /* handle */ } func (m *rocmModel) Err() error { m.mu.Lock() defer m.mu.Unlock() @@ -226,6 +252,9 @@ func (m *rocmModel) Err() error { } // Close releases the llama-server subprocess and all associated resources. 
+// +// m, err := backend.LoadModel("/data/model.gguf") +// defer m.Close() func (m *rocmModel) Close() error { return m.server.stop() } @@ -248,7 +277,7 @@ func (m *rocmModel) recordMetrics(promptTokens, generatedTokens int, start, deco decode := now.Sub(decodeStart) prefill := total - decode - metrics := inference.GenerateMetrics{ + result := inference.GenerateMetrics{ PromptTokens: promptTokens, GeneratedTokens: generatedTokens, PrefillDuration: prefill, @@ -256,19 +285,19 @@ func (m *rocmModel) recordMetrics(promptTokens, generatedTokens int, start, deco TotalDuration: total, } if prefill > 0 && promptTokens > 0 { - metrics.PrefillTokensPerSec = float64(promptTokens) / prefill.Seconds() + result.PrefillTokensPerSec = float64(promptTokens) / prefill.Seconds() } if decode > 0 && generatedTokens > 0 { - metrics.DecodeTokensPerSec = float64(generatedTokens) / decode.Seconds() + result.DecodeTokensPerSec = float64(generatedTokens) / decode.Seconds() } // Try to get VRAM stats — best effort. - if vram, err := GetVRAMInfo(); err == nil { - metrics.PeakMemoryBytes = vram.Used - metrics.ActiveMemoryBytes = vram.Used + if vramInfo, err := GetVRAMInfo(); err == nil { + result.PeakMemoryBytes = vramInfo.Used + result.ActiveMemoryBytes = vramInfo.Used } m.mu.Lock() - m.metrics = metrics + m.metrics = result m.mu.Unlock() } diff --git a/rocm.go b/rocm.go index 4518ecd..39612c2 100644 --- a/rocm.go +++ b/rocm.go @@ -12,8 +12,9 @@ // // m, err := inference.LoadModel("/path/to/model.gguf") // defer m.Close() +// output := core.NewBuilder() // for tok := range m.Generate(ctx, "Hello", inference.WithMaxTokens(128)) { -// fmt.Print(tok.Text) +// output.WriteString(tok.Text) // } // // # Requirements @@ -27,7 +28,7 @@ package rocm // VRAMInfo reports GPU video memory usage in bytes. 
// // info, err := rocm.GetVRAMInfo() -// fmt.Printf("VRAM: %d MiB used / %d MiB total", info.Used/(1024*1024), info.Total/(1024*1024)) +// core.Print(c, "VRAM: %d MiB used / %d MiB total", info.Used/(1024*1024), info.Total/(1024*1024)) type VRAMInfo struct { Total uint64 Used uint64 @@ -38,7 +39,7 @@ type VRAMInfo struct { // // models, _ := rocm.DiscoverModels("/data/lem/gguf") // for _, m := range models { -// fmt.Printf("%s (%s %s, ctx=%d)\n", m.Name, m.Architecture, m.Quantisation, m.ContextLen) +// core.Print(c, "%s (%s %s, ctx=%d)", m.Name, m.Architecture, m.Quantisation, m.ContextLen) // } type ModelInfo struct { Path string // full path to .gguf file diff --git a/rocm_integration_test.go b/rocm_integration_test.go index 1856036..5253c64 100644 --- a/rocm_integration_test.go +++ b/rocm_integration_test.go @@ -5,12 +5,12 @@ package rocm import ( "context" "os" - "path/filepath" - "strings" "sync" "testing" "time" + "dappco.re/go/core" + "forge.lthn.ai/core/go-inference" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -35,45 +35,45 @@ func skipIfNoROCm(t *testing.T) { } } -func TestROCm_LoadAndGenerate(t *testing.T) { +func TestROCm_LoadAndGenerate_Good(t *testing.T) { skipIfNoROCm(t) skipIfNoModel(t) - b := &rocmBackend{} - require.True(t, b.Available()) + backend := &rocmBackend{} + require.True(t, backend.Available()) - m, err := b.LoadModel(testModel, inference.WithContextLen(2048)) + model, err := backend.LoadModel(testModel, inference.WithContextLen(2048)) require.NoError(t, err) - defer m.Close() + defer model.Close() - assert.Equal(t, "gemma3", m.ModelType()) + assert.Equal(t, "gemma3", model.ModelType()) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - var tokens []string - for tok := range m.Generate(ctx, "The capital of France is", inference.WithMaxTokens(16)) { - tokens = append(tokens, tok.Text) + var tokenTexts []string + for tok := range model.Generate(ctx, "The capital of France is", inference.WithMaxTokens(16)) { + tokenTexts = append(tokenTexts, tok.Text) } - require.NoError(t, m.Err()) - require.NotEmpty(t, tokens, "expected at least one token") + require.NoError(t, model.Err()) + require.NotEmpty(t, tokenTexts, "expected at least one token") - full := "" - for _, tok := range tokens { - full += tok + fullText := "" + for _, tok := range tokenTexts { + fullText += tok } - t.Logf("Generated: %s", full) + t.Logf("Generated: %s", fullText) } -func TestROCm_Chat(t *testing.T) { +func TestROCm_Chat_Good(t *testing.T) { skipIfNoROCm(t) skipIfNoModel(t) - b := &rocmBackend{} - m, err := b.LoadModel(testModel, inference.WithContextLen(2048)) + backend := &rocmBackend{} + model, err := backend.LoadModel(testModel, inference.WithContextLen(2048)) require.NoError(t, err) - defer m.Close() + defer model.Close() ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -82,89 +82,89 @@ func TestROCm_Chat(t *testing.T) { {Role: "user", Content: "Say hello in exactly three words."}, } - var tokens []string - for tok := range m.Chat(ctx, messages, inference.WithMaxTokens(32)) { - tokens = append(tokens, tok.Text) + var tokenTexts []string + for tok := range model.Chat(ctx, messages, inference.WithMaxTokens(32)) { + tokenTexts = append(tokenTexts, tok.Text) } - require.NoError(t, m.Err()) - require.NotEmpty(t, tokens, "expected at least one token") + require.NoError(t, model.Err()) + require.NotEmpty(t, tokenTexts, "expected at least one token") - full := "" - for _, tok := range 
tokens { - full += tok + fullText := "" + for _, tok := range tokenTexts { + fullText += tok } - t.Logf("Chat response: %s", full) + t.Logf("Chat response: %s", fullText) } -func TestROCm_ContextCancellation(t *testing.T) { +func TestROCm_ContextCancellation_Ugly(t *testing.T) { skipIfNoROCm(t) skipIfNoModel(t) - b := &rocmBackend{} - m, err := b.LoadModel(testModel, inference.WithContextLen(2048)) + backend := &rocmBackend{} + model, err := backend.LoadModel(testModel, inference.WithContextLen(2048)) require.NoError(t, err) - defer m.Close() + defer model.Close() ctx, cancel := context.WithCancel(context.Background()) - var count int - for tok := range m.Generate(ctx, "Write a very long story about dragons", inference.WithMaxTokens(256)) { + var tokenCount int + for tok := range model.Generate(ctx, "Write a very long story about dragons", inference.WithMaxTokens(256)) { _ = tok - count++ - if count >= 3 { + tokenCount++ + if tokenCount >= 3 { cancel() } } - t.Logf("Got %d tokens before cancel", count) - assert.GreaterOrEqual(t, count, 3) + t.Logf("Got %d tokens before cancel", tokenCount) + assert.GreaterOrEqual(t, tokenCount, 3) } -func TestROCm_GracefulShutdown(t *testing.T) { +func TestROCm_GracefulShutdown_Good(t *testing.T) { skipIfNoROCm(t) skipIfNoModel(t) - b := &rocmBackend{} - m, err := b.LoadModel(testModel, inference.WithContextLen(2048)) + backend := &rocmBackend{} + model, err := backend.LoadModel(testModel, inference.WithContextLen(2048)) require.NoError(t, err) - defer m.Close() + defer model.Close() // Cancel mid-stream. ctx1, cancel1 := context.WithCancel(context.Background()) - var count1 int - for tok := range m.Generate(ctx1, "Write a long story about space exploration", inference.WithMaxTokens(256)) { + var firstTokenCount int + for tok := range model.Generate(ctx1, "Write a long story about space exploration", inference.WithMaxTokens(256)) { _ = tok - count1++ - if count1 >= 5 { + firstTokenCount++ + if firstTokenCount >= 5 { cancel1() } } - t.Logf("First generation: %d tokens before cancel", count1) + t.Logf("First generation: %d tokens before cancel", firstTokenCount) // Generate again on the same model — server should still be alive. 
ctx2, cancel2 := context.WithTimeout(context.Background(), 30*time.Second) defer cancel2() - var count2 int - for tok := range m.Generate(ctx2, "The capital of France is", inference.WithMaxTokens(16)) { + var secondTokenCount int + for tok := range model.Generate(ctx2, "The capital of France is", inference.WithMaxTokens(16)) { _ = tok - count2++ + secondTokenCount++ } - require.NoError(t, m.Err()) - assert.Greater(t, count2, 0, "expected tokens from second generation after cancel") - t.Logf("Second generation: %d tokens", count2) + require.NoError(t, model.Err()) + assert.Greater(t, secondTokenCount, 0, "expected tokens from second generation after cancel") + t.Logf("Second generation: %d tokens", secondTokenCount) } -func TestROCm_ConcurrentRequests(t *testing.T) { +func TestROCm_ConcurrentRequests_Good(t *testing.T) { skipIfNoROCm(t) skipIfNoModel(t) - b := &rocmBackend{} - m, err := b.LoadModel(testModel, inference.WithContextLen(2048)) + backend := &rocmBackend{} + model, err := backend.LoadModel(testModel, inference.WithContextLen(2048)) require.NoError(t, err) - defer m.Close() + defer model.Close() const numGoroutines = 3 results := make([]string, numGoroutines) @@ -175,25 +175,25 @@ func TestROCm_ConcurrentRequests(t *testing.T) { "The capital of Italy is", } - var wg sync.WaitGroup - wg.Add(numGoroutines) + var waitGroup sync.WaitGroup + waitGroup.Add(numGoroutines) for i := range numGoroutines { go func(idx int) { - defer wg.Done() + defer waitGroup.Done() ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - var sb strings.Builder - for tok := range m.Generate(ctx, prompts[idx], inference.WithMaxTokens(16)) { - sb.WriteString(tok.Text) + builder := core.NewBuilder() + for tok := range model.Generate(ctx, prompts[idx], inference.WithMaxTokens(16)) { + builder.WriteString(tok.Text) } - results[idx] = sb.String() + results[idx] = builder.String() }(i) } - wg.Wait() + waitGroup.Wait() for i, result := range results { t.Logf("Goroutine %d: %s", i, result) @@ -201,14 +201,14 @@ func TestROCm_ConcurrentRequests(t *testing.T) { } } -func TestROCm_Classify(t *testing.T) { +func TestROCm_Classify_Good(t *testing.T) { skipIfNoROCm(t) skipIfNoModel(t) - b := &rocmBackend{} - m, err := b.LoadModel(testModel, inference.WithContextLen(2048)) + backend := &rocmBackend{} + model, err := backend.LoadModel(testModel, inference.WithContextLen(2048)) require.NoError(t, err) - defer m.Close() + defer model.Close() ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -218,24 +218,24 @@ func TestROCm_Classify(t *testing.T) { "2 + 2 =", } - results, err := m.Classify(ctx, prompts) + results, err := model.Classify(ctx, prompts) require.NoError(t, err) require.Len(t, results, 2) - for i, r := range results { - assert.NotEmpty(t, r.Token.Text, "classify result %d should have a token", i) - t.Logf("Classify %d: %q", i, r.Token.Text) + for i, classifyResult := range results { + assert.NotEmpty(t, classifyResult.Token.Text, "classify result %d should have a token", i) + t.Logf("Classify %d: %q", i, classifyResult.Token.Text) } } -func TestROCm_BatchGenerate(t *testing.T) { +func TestROCm_BatchGenerate_Good(t *testing.T) { skipIfNoROCm(t) skipIfNoModel(t) - b := &rocmBackend{} - m, err := b.LoadModel(testModel, inference.WithContextLen(2048)) + backend := &rocmBackend{} + model, err := backend.LoadModel(testModel, inference.WithContextLen(2048)) require.NoError(t, err) - defer m.Close() + defer model.Close() ctx, cancel := 
context.WithTimeout(context.Background(), 60*time.Second) defer cancel() @@ -245,70 +245,70 @@ func TestROCm_BatchGenerate(t *testing.T) { "The capital of Germany is", } - results, err := m.BatchGenerate(ctx, prompts, inference.WithMaxTokens(8)) + results, err := model.BatchGenerate(ctx, prompts, inference.WithMaxTokens(8)) require.NoError(t, err) require.Len(t, results, 2) - for i, r := range results { - require.NoError(t, r.Err, "batch result %d error", i) - assert.NotEmpty(t, r.Tokens, "batch result %d should have tokens", i) + for i, batchResult := range results { + require.NoError(t, batchResult.Err, "batch result %d error", i) + assert.NotEmpty(t, batchResult.Tokens, "batch result %d should have tokens", i) - var sb strings.Builder - for _, tok := range r.Tokens { - sb.WriteString(tok.Text) + builder := core.NewBuilder() + for _, tok := range batchResult.Tokens { + builder.WriteString(tok.Text) } - t.Logf("Batch %d: %s", i, sb.String()) + t.Logf("Batch %d: %s", i, builder.String()) } } -func TestROCm_InfoAndMetrics(t *testing.T) { +func TestROCm_InfoAndMetrics_Good(t *testing.T) { skipIfNoROCm(t) skipIfNoModel(t) - b := &rocmBackend{} - m, err := b.LoadModel(testModel, inference.WithContextLen(2048)) + backend := &rocmBackend{} + model, err := backend.LoadModel(testModel, inference.WithContextLen(2048)) require.NoError(t, err) - defer m.Close() + defer model.Close() // Info should be populated from GGUF metadata. - info := m.Info() - assert.Equal(t, "gemma3", info.Architecture) - assert.Greater(t, info.NumLayers, 0, "expected non-zero layer count") - assert.Greater(t, info.QuantBits, 0, "expected non-zero quant bits") + modelInfo := model.Info() + assert.Equal(t, "gemma3", modelInfo.Architecture) + assert.Greater(t, modelInfo.NumLayers, 0, "expected non-zero layer count") + assert.Greater(t, modelInfo.QuantBits, 0, "expected non-zero quant bits") t.Logf("Info: arch=%s layers=%d quant=%d-bit group=%d", - info.Architecture, info.NumLayers, info.QuantBits, info.QuantGroup) + modelInfo.Architecture, modelInfo.NumLayers, modelInfo.QuantBits, modelInfo.QuantGroup) // Generate some tokens to populate metrics. 
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - for range m.Generate(ctx, "Hello", inference.WithMaxTokens(4)) { + for range model.Generate(ctx, "Hello", inference.WithMaxTokens(4)) { } - require.NoError(t, m.Err()) + require.NoError(t, model.Err()) - met := m.Metrics() - assert.Greater(t, met.GeneratedTokens, 0, "expected generated tokens") - assert.Greater(t, met.TotalDuration, time.Duration(0), "expected non-zero duration") - assert.Greater(t, met.DecodeTokensPerSec, float64(0), "expected non-zero decode throughput") + metrics := model.Metrics() + assert.Greater(t, metrics.GeneratedTokens, 0, "expected generated tokens") + assert.Greater(t, metrics.TotalDuration, time.Duration(0), "expected non-zero duration") + assert.Greater(t, metrics.DecodeTokensPerSec, float64(0), "expected non-zero decode throughput") t.Logf("Metrics: gen=%d tok, total=%s, decode=%.1f tok/s, vram=%d MiB", - met.GeneratedTokens, met.TotalDuration, met.DecodeTokensPerSec, - met.ActiveMemoryBytes/(1024*1024)) + metrics.GeneratedTokens, metrics.TotalDuration, metrics.DecodeTokensPerSec, + metrics.ActiveMemoryBytes/(1024*1024)) } -func TestROCm_DiscoverModels(t *testing.T) { - dir := filepath.Dir(testModel) - if _, err := os.Stat(dir); err != nil { +func TestROCm_DiscoverModels_Good(t *testing.T) { + modelDir := core.PathDir(testModel) + if _, err := os.Stat(modelDir); err != nil { t.Skip("model directory not available") } - models, err := DiscoverModels(dir) + models, err := DiscoverModels(modelDir) require.NoError(t, err) - require.NotEmpty(t, models, "expected at least one model in %s", dir) + require.NotEmpty(t, models, "expected at least one model in %s", modelDir) - for _, m := range models { - t.Logf("Found: %s (%s %s %s, ctx=%d)", filepath.Base(m.Path), m.Architecture, m.Parameters, m.Quantisation, m.ContextLen) - assert.NotEmpty(t, m.Architecture) - assert.NotEmpty(t, m.Name) - assert.Greater(t, m.FileSize, int64(0)) + for _, discoveredModel := range models { + t.Logf("Found: %s (%s %s %s, ctx=%d)", core.PathBase(discoveredModel.Path), discoveredModel.Architecture, discoveredModel.Parameters, discoveredModel.Quantisation, discoveredModel.ContextLen) + assert.NotEmpty(t, discoveredModel.Architecture) + assert.NotEmpty(t, discoveredModel.Name) + assert.Greater(t, discoveredModel.FileSize, int64(0)) } } diff --git a/server.go b/server.go index 071e759..7757029 100644 --- a/server.go +++ b/server.go @@ -4,15 +4,15 @@ package rocm import ( "context" - "fmt" "net" "os" "os/exec" "strconv" - "strings" "syscall" "time" + "dappco.re/go/core" + coreerr "forge.lthn.ai/core/go-log" "forge.lthn.ai/core/go-rocm/internal/llamacpp" ) @@ -38,28 +38,31 @@ func (s *server) alive() bool { // findLlamaServer locates the llama-server binary. // Checks ROCM_LLAMA_SERVER_PATH first, then PATH. 
+// +// path, err := findLlamaServer() +// // path == "/usr/local/bin/llama-server" func findLlamaServer() (string, error) { - if p := os.Getenv("ROCM_LLAMA_SERVER_PATH"); p != "" { - if _, err := os.Stat(p); err != nil { - return "", coreerr.E("rocm.findLlamaServer", "llama-server not found at ROCM_LLAMA_SERVER_PATH="+p, err) + if binaryPath := core.Env("ROCM_LLAMA_SERVER_PATH"); binaryPath != "" { + if !(&core.Fs{}).New("/").Exists(binaryPath) { + return "", coreerr.E("rocm.findLlamaServer", "llama-server not found at ROCM_LLAMA_SERVER_PATH="+binaryPath, nil) } - return p, nil + return binaryPath, nil } - p, err := exec.LookPath("llama-server") + binaryPath, err := exec.LookPath("llama-server") if err != nil { return "", coreerr.E("rocm.findLlamaServer", "llama-server not found in PATH", err) } - return p, nil + return binaryPath, nil } // freePort asks the kernel for a free TCP port on localhost. func freePort() (int, error) { - ln, err := net.Listen("tcp", "127.0.0.1:0") + listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { return 0, coreerr.E("rocm.freePort", "listen for free port", err) } - port := ln.Addr().(*net.TCPAddr).Port - ln.Close() + port := listener.Addr().(*net.TCPAddr).Port + listener.Close() return port, nil } @@ -69,11 +72,11 @@ func freePort() (int, error) { func serverEnv() []string { environ := os.Environ() env := make([]string, 0, len(environ)+1) - for _, e := range environ { - if strings.HasPrefix(e, "HIP_VISIBLE_DEVICES=") { + for _, envEntry := range environ { + if core.HasPrefix(envEntry, "HIP_VISIBLE_DEVICES=") { continue } - env = append(env, e) + env = append(env, envEntry) } env = append(env, "HIP_VISIBLE_DEVICES=0") return env @@ -82,7 +85,10 @@ func serverEnv() []string { // startServer spawns llama-server and waits for it to become ready. // It selects a free port automatically, retrying up to 3 times if the // process exits during startup (e.g. port conflict). -func startServer(binary, modelPath string, gpuLayers, ctxSize, parallelSlots int) (*server, error) { +// +// s, err := startServer("/usr/local/bin/llama-server", "/data/model.gguf", 99, 4096, 4) +// defer s.stop() +func startServer(binary, modelPath string, gpuLayers, contextSize, parallelSlots int) (*server, error) { if gpuLayers < 0 { gpuLayers = 999 } @@ -102,8 +108,8 @@ func startServer(binary, modelPath string, gpuLayers, ctxSize, parallelSlots int "--port", strconv.Itoa(port), "--n-gpu-layers", strconv.Itoa(gpuLayers), } - if ctxSize > 0 { - args = append(args, "--ctx-size", strconv.Itoa(ctxSize)) + if contextSize > 0 { + args = append(args, "--ctx-size", strconv.Itoa(contextSize)) } if parallelSlots > 0 { args = append(args, "--parallel", strconv.Itoa(parallelSlots)) @@ -116,39 +122,39 @@ func startServer(binary, modelPath string, gpuLayers, ctxSize, parallelSlots int return nil, coreerr.E("rocm.startServer", "start llama-server", err) } - s := &server{ + subprocess := &server{ cmd: cmd, port: port, - client: llamacpp.NewClient(fmt.Sprintf("http://127.0.0.1:%d", port)), + client: llamacpp.NewClient(core.Sprintf("http://127.0.0.1:%d", port)), exited: make(chan struct{}), } go func() { - s.exitErr = cmd.Wait() - close(s.exited) + subprocess.exitErr = cmd.Wait() + close(subprocess.exited) }() ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) - err = s.waitReady(ctx) + err = subprocess.waitReady(ctx) cancel() if err == nil { - return s, nil + return subprocess, nil } // Only retry if the process actually exited (e.g. port conflict). 
// A timeout means the server is stuck, not a port issue. select { - case <-s.exited: - _ = s.stop() - lastErr = coreerr.E("rocm.startServer", fmt.Sprintf("attempt %d", attempt+1), err) + case <-subprocess.exited: + _ = subprocess.stop() + lastErr = coreerr.E("rocm.startServer", core.Sprintf("attempt %d", attempt+1), err) continue default: - _ = s.stop() + _ = subprocess.stop() return nil, coreerr.E("rocm.startServer", "llama-server not ready", err) } } - return nil, coreerr.E("rocm.startServer", fmt.Sprintf("server failed after %d attempts", maxAttempts), lastErr) + return nil, coreerr.E("rocm.startServer", core.Sprintf("server failed after %d attempts", maxAttempts), lastErr) } // waitReady polls the health endpoint until the server is ready. diff --git a/server_test.go b/server_test.go index 9d4e527..1c529a8 100644 --- a/server_test.go +++ b/server_test.go @@ -4,10 +4,10 @@ package rocm import ( "context" - "os" - "strings" "testing" + "dappco.re/go/core" + "forge.lthn.ai/core/go-inference" coreerr "forge.lthn.ai/core/go-log" "github.com/stretchr/testify/assert" @@ -34,6 +34,14 @@ func TestFindLlamaServer_Bad_EnvPathMissing(t *testing.T) { assert.ErrorContains(t, err, "not found") } +func TestFindLlamaServer_Ugly_EmptyPATH(t *testing.T) { + // With no ROCM_LLAMA_SERVER_PATH set and an empty PATH, LookPath must fail. + t.Setenv("ROCM_LLAMA_SERVER_PATH", "") + t.Setenv("PATH", "") + _, err := findLlamaServer() + assert.Error(t, err) +} + func TestFreePort_Good(t *testing.T) { port, err := freePort() require.NoError(t, err) @@ -50,11 +58,26 @@ func TestFreePort_Good_UniquePerCall(t *testing.T) { _ = p2 } +func TestFreePort_Bad_InvalidAddr(t *testing.T) { + // freePort always binds to localhost so it can't fail on a valid machine; + // this test documents that the port is always in the valid range. + port, err := freePort() + require.NoError(t, err) + assert.Greater(t, port, 1023, "expected unprivileged port") +} + +func TestFreePort_Ugly_ReturnsUsablePort(t *testing.T) { + // The returned port should be bindable a second time. + port, err := freePort() + require.NoError(t, err) + assert.NotZero(t, port) +} + func TestServerEnv_Good_SetsHIPVisibleDevices(t *testing.T) { env := serverEnv() var hipVals []string for _, e := range env { - if strings.HasPrefix(e, "HIP_VISIBLE_DEVICES=") { + if core.HasPrefix(e, "HIP_VISIBLE_DEVICES=") { hipVals = append(hipVals, e) } } @@ -66,7 +89,27 @@ func TestServerEnv_Good_FiltersExistingHIPVisibleDevices(t *testing.T) { env := serverEnv() var hipVals []string for _, e := range env { - if strings.HasPrefix(e, "HIP_VISIBLE_DEVICES=") { + if core.HasPrefix(e, "HIP_VISIBLE_DEVICES=") { + hipVals = append(hipVals, e) + } + } + assert.Equal(t, []string{"HIP_VISIBLE_DEVICES=0"}, hipVals) +} + +func TestServerEnv_Bad_NilEnviron(t *testing.T) { + // serverEnv must never panic even when called with unusual env state. + // It always appends HIP_VISIBLE_DEVICES=0 regardless of ambient env. + env := serverEnv() + assert.NotEmpty(t, env) +} + +func TestServerEnv_Ugly_MultipleHIPEntries(t *testing.T) { + // Even if multiple HIP_VISIBLE_DEVICES entries somehow existed, only one must remain. 
+ t.Setenv("HIP_VISIBLE_DEVICES", "2,3") + env := serverEnv() + var hipVals []string + for _, e := range env { + if core.HasPrefix(e, "HIP_VISIBLE_DEVICES=") { hipVals = append(hipVals, e) } } @@ -75,12 +118,32 @@ func TestServerEnv_Good_FiltersExistingHIPVisibleDevices(t *testing.T) { func TestAvailable_Good(t *testing.T) { b := &rocmBackend{} - if _, err := os.Stat("/dev/kfd"); err != nil { + if !(&core.Fs{}).New("/").Exists("/dev/kfd") { t.Skip("no ROCm hardware") } assert.True(t, b.Available()) } +func TestAvailable_Bad_NoDevice(t *testing.T) { + // When /dev/kfd is absent, Available must return false. + if (&core.Fs{}).New("/").Exists("/dev/kfd") { + t.Skip("ROCm device present — skip no-device bad path on this machine") + } + b := &rocmBackend{} + assert.False(t, b.Available()) +} + +func TestAvailable_Ugly_NoLlamaServer(t *testing.T) { + // Even with /dev/kfd present, Available must be false if llama-server is missing. + // We can't create /dev/kfd in a test, so verify the condition via findLlamaServer. + t.Setenv("PATH", "") + t.Setenv("ROCM_LLAMA_SERVER_PATH", "") + b := &rocmBackend{} + // If kfd is present but llama-server missing, Available returns false. + // If kfd is absent, Available also returns false. Either way, not panic. + _ = b.Available() +} + func TestServerAlive_Good_Running(t *testing.T) { s := &server{exited: make(chan struct{})} assert.True(t, s.alive()) @@ -93,6 +156,39 @@ func TestServerAlive_Good_Exited(t *testing.T) { assert.False(t, s.alive()) } +func TestServerAlive_Bad_NilExited(t *testing.T) { + // A server with a nil exited channel panics — this documents the contract: + // exited must always be initialised before use. + // We test the well-formed bad state: exited closed with nil exitErr. + exited := make(chan struct{}) + close(exited) + s := &server{exited: exited, exitErr: nil} + assert.False(t, s.alive()) +} + +func TestServerAlive_Ugly_ExitedAfterStart(t *testing.T) { + // alive transitions from true to false when the channel is closed. + exited := make(chan struct{}) + s := &server{exited: exited} + assert.True(t, s.alive()) + close(exited) + assert.False(t, s.alive()) +} + +func TestGenerate_Good_YieldsNoTokensOnEmptyServer(t *testing.T) { + // A newly-created dead server produces zero tokens and records an error. + exited := make(chan struct{}) + close(exited) + s := &server{exited: exited, exitErr: nil} + m := &rocmModel{server: s} + + var tokenCount int + for range m.Generate(context.Background(), "hello") { + tokenCount++ + } + assert.Equal(t, 0, tokenCount) +} + func TestGenerate_Bad_ServerDead(t *testing.T) { exited := make(chan struct{}) close(exited) @@ -102,14 +198,39 @@ func TestGenerate_Bad_ServerDead(t *testing.T) { } m := &rocmModel{server: s} - var count int + var tokenCount int for range m.Generate(context.Background(), "hello") { - count++ + tokenCount++ } - assert.Equal(t, 0, count) + assert.Equal(t, 0, tokenCount) assert.ErrorContains(t, m.Err(), "server has exited") } +func TestGenerate_Ugly_ErrClearedBetweenCalls(t *testing.T) { + // Err is reset to nil on each Generate call start. + exited := make(chan struct{}) + close(exited) + s := &server{exited: exited, exitErr: coreerr.E("test", "killed", nil)} + m := &rocmModel{server: s} + + for range m.Generate(context.Background(), "first") { + } + assert.Error(t, m.Err()) +} + +func TestStartServer_Good_RejectsNegativeLayers(t *testing.T) { + // gpuLayers=-1 must be converted to 999 (all layers on GPU). + // /bin/false exits immediately so we observe the retry behaviour. 
+ _, err := startServer("/bin/false", "/nonexistent/model.gguf", -1, 0, 0) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed after 3 attempts") +} + +func TestStartServer_Bad_BinaryNotFound(t *testing.T) { + _, err := startServer("/nonexistent/binary", "/nonexistent/model.gguf", 0, 0, 0) + require.Error(t, err) +} + func TestStartServer_Ugly_RetriesOnProcessExit(t *testing.T) { // /bin/false starts successfully but exits immediately with code 1. // startServer should retry up to 3 times, then fail. @@ -118,6 +239,20 @@ func TestStartServer_Ugly_RetriesOnProcessExit(t *testing.T) { assert.Contains(t, err.Error(), "failed after 3 attempts") } +func TestChat_Good_EmptyMessages(t *testing.T) { + // Chat with an empty message list on a dead server yields no tokens. + exited := make(chan struct{}) + close(exited) + s := &server{exited: exited, exitErr: nil} + m := &rocmModel{server: s} + + var tokenCount int + for range m.Chat(context.Background(), nil) { + tokenCount++ + } + assert.Equal(t, 0, tokenCount) +} + func TestChat_Bad_ServerDead(t *testing.T) { exited := make(chan struct{}) close(exited) @@ -128,10 +263,29 @@ func TestChat_Bad_ServerDead(t *testing.T) { m := &rocmModel{server: s} msgs := []inference.Message{{Role: "user", Content: "hello"}} - var count int + var tokenCount int for range m.Chat(context.Background(), msgs) { - count++ + tokenCount++ } - assert.Equal(t, 0, count) + assert.Equal(t, 0, tokenCount) assert.ErrorContains(t, m.Err(), "server has exited") } + +func TestChat_Ugly_MultipleRolesOnDeadServer(t *testing.T) { + // Chat with multiple roles on a dead server must still return safely. + exited := make(chan struct{}) + close(exited) + s := &server{exited: exited, exitErr: coreerr.E("test", "killed", nil)} + m := &rocmModel{server: s} + + msgs := []inference.Message{ + {Role: "system", Content: "You are helpful."}, + {Role: "user", Content: "Hello"}, + } + var tokenCount int + for range m.Chat(context.Background(), msgs) { + tokenCount++ + } + assert.Equal(t, 0, tokenCount) + assert.Error(t, m.Err()) +} diff --git a/vram.go b/vram.go index 99a5c5d..47ce93b 100644 --- a/vram.go +++ b/vram.go @@ -3,10 +3,9 @@ package rocm import ( - "os" - "path/filepath" "strconv" - "strings" + + "dappco.re/go/core" coreerr "forge.lthn.ai/core/go-log" ) @@ -15,17 +14,10 @@ import ( // It identifies the dGPU by selecting the card with the largest VRAM total, // which avoids hardcoding card numbers (e.g. card0=iGPU, card1=dGPU on Ryzen). // -// Note: total and used are read non-atomically from sysfs; transient -// inconsistencies are possible under heavy allocation churn. 
-// // info, err := rocm.GetVRAMInfo() -// fmt.Printf("VRAM: %d MiB used / %d MiB total (free: %d MiB)", -// info.Used/(1024*1024), info.Total/(1024*1024), info.Free/(1024*1024)) +// // info.Total == 17179869184, info.Used == 2147483648, info.Free == 15032385536 func GetVRAMInfo() (VRAMInfo, error) { - cards, err := filepath.Glob("/sys/class/drm/card[0-9]*/device/mem_info_vram_total") - if err != nil { - return VRAMInfo{}, coreerr.E("rocm.GetVRAMInfo", "glob vram sysfs", err) - } + cards := core.PathGlob("/sys/class/drm/card[0-9]*/device/mem_info_vram_total") if len(cards) == 0 { return VRAMInfo{}, coreerr.E("rocm.GetVRAMInfo", "no GPU VRAM info found in sysfs", nil) } @@ -40,7 +32,7 @@ func GetVRAMInfo() (VRAMInfo, error) { } if total > bestTotal { bestTotal = total - bestDir = filepath.Dir(totalPath) + bestDir = core.PathDir(totalPath) } } @@ -48,27 +40,27 @@ func GetVRAMInfo() (VRAMInfo, error) { return VRAMInfo{}, coreerr.E("rocm.GetVRAMInfo", "no readable VRAM sysfs entries", nil) } - used, err := readSysfsUint64(filepath.Join(bestDir, "mem_info_vram_used")) + used, err := readSysfsUint64(core.Path(bestDir, "mem_info_vram_used")) if err != nil { return VRAMInfo{}, coreerr.E("rocm.GetVRAMInfo", "read vram used", err) } - free := uint64(0) + freeBytes := uint64(0) if bestTotal > used { - free = bestTotal - used + freeBytes = bestTotal - used } return VRAMInfo{ Total: bestTotal, Used: used, - Free: free, + Free: freeBytes, }, nil } func readSysfsUint64(path string) (uint64, error) { - data, err := os.ReadFile(path) - if err != nil { - return 0, err + result := (&core.Fs{}).New("/").Read(path) + if !result.OK { + return 0, coreerr.E("rocm.readSysfsUint64", "read sysfs file", result.Value.(error)) } - return strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64) + return strconv.ParseUint(core.Trim(result.Value.(string)), 10, 64) } diff --git a/vram_test.go b/vram_test.go index 325dfbc..fc7dd2a 100644 --- a/vram_test.go +++ b/vram_test.go @@ -4,16 +4,17 @@ package rocm import ( "os" - "path/filepath" "testing" + "dappco.re/go/core" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestReadSysfsUint64_Good(t *testing.T) { dir := t.TempDir() - path := filepath.Join(dir, "test_value") + path := core.Path(dir, "test_value") require.NoError(t, os.WriteFile(path, []byte("17163091968\n"), 0644)) val, err := readSysfsUint64(path) @@ -28,7 +29,7 @@ func TestReadSysfsUint64_Bad_NotFound(t *testing.T) { func TestReadSysfsUint64_Bad_InvalidContent(t *testing.T) { dir := t.TempDir() - path := filepath.Join(dir, "bad_value") + path := core.Path(dir, "bad_value") require.NoError(t, os.WriteFile(path, []byte("not-a-number\n"), 0644)) _, err := readSysfsUint64(path) @@ -37,13 +38,24 @@ func TestReadSysfsUint64_Bad_InvalidContent(t *testing.T) { func TestReadSysfsUint64_Bad_EmptyFile(t *testing.T) { dir := t.TempDir() - path := filepath.Join(dir, "empty_value") + path := core.Path(dir, "empty_value") require.NoError(t, os.WriteFile(path, []byte(""), 0644)) _, err := readSysfsUint64(path) assert.Error(t, err) } +func TestReadSysfsUint64_Ugly_WhitespaceAroundValue(t *testing.T) { + // sysfs files often contain a trailing newline; readSysfsUint64 must trim it. 
+ dir := t.TempDir() + path := core.Path(dir, "whitespace_value") + require.NoError(t, os.WriteFile(path, []byte(" 17163091968 \n"), 0644)) + + val, err := readSysfsUint64(path) + require.NoError(t, err) + assert.Equal(t, uint64(17163091968), val) +} + func TestGetVRAMInfo_Good(t *testing.T) { info, err := GetVRAMInfo() if err != nil { @@ -55,3 +67,23 @@ func TestGetVRAMInfo_Good(t *testing.T) { assert.Greater(t, info.Used, uint64(0), "expected some VRAM in use") assert.Equal(t, info.Total-info.Used, info.Free, "Free should equal Total-Used") } + +func TestGetVRAMInfo_Bad_NoGPU(t *testing.T) { + // On non-ROCm machines GetVRAMInfo must return an error, not panic. + _, err := GetVRAMInfo() + if err == nil { + t.Skip("ROCm sysfs present — not applicable on this machine") + } + assert.Error(t, err) +} + +func TestGetVRAMInfo_Ugly_FreeComputed(t *testing.T) { + // Invariant: Free == Total - Used whenever Total >= Used. + info, err := GetVRAMInfo() + if err != nil { + t.Skipf("no VRAM sysfs info available: %v", err) + } + if info.Total >= info.Used { + assert.Equal(t, info.Total-info.Used, info.Free) + } +}