Compare commits


4 commits

Author SHA1 Message Date
Claude
661d37c5c1
style(ax): rename loop variable e→envEntry for AX naming compliance
Co-Authored-By: Virgil <virgil@lethean.io>
2026-03-31 08:25:10 +01:00
Claude
3073c019f8
feat(ax): pass 1 — remove banned imports, AX naming, test coverage
Replace fmt, strings, encoding/json, path/filepath, and os with dappco.re/go/core
primitives across all non-test source files. Add core as direct dep.
Rename cfg→configuration, ctxLen→contextLen, met→result. Add usage
example comments on all exported functions. Ensure all three test
categories (Good/Bad/Ugly) exist for every tested function.

Residuals for pass 2:
- server.go: os/exec (no c.Process() available) + os.Environ() (no core wrapper)
- internal/gguf: os retained for *os.File.Stat() via core.Fs.Open result

Co-Authored-By: Virgil <virgil@lethean.io>
2026-03-31 08:24:58 +01:00
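The "core.Fs.Open result" residual above refers to the result-style API visible in the internal/gguf hunk further down this page. A minimal sketch of the pattern, with the core API shape inferred from this diff rather than from core's documentation:

    openResult := (&core.Fs{}).New("/").Open(path)
    if !openResult.OK {
        return Metadata{}, openResult.Value.(error) // failure: Value carries the error
    }
    file := openResult.Value.(*os.File) // success: Value carries the *os.File
    defer file.Close()
    info, err := file.Stat() // os retained: Stat() has no core wrapper yet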
Claude
523abc6509
feat(ax): pass 2 — replace banned imports, rename variables, add AX comments
Replace fmt, strings, path/filepath, and encoding/json with core equivalents throughout
all packages. Rename cfg→configuration, srv→server/subprocess, ftName→fileTypeName,
ctxSize→contextSize. Add usage-example doc-comments to every exported symbol.
Update all test names to the TestSubject_Function_{Good,Bad,Ugly} convention.

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-31 08:24:34 +01:00
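The TestSubject_Function_{Good,Bad,Ugly} convention named here is visible in the test hunks below. Schematically, with hypothetical bodies:

    func TestReadMetadata_Good(t *testing.T) { /* well-formed input, expect success */ }
    func TestReadMetadata_Bad(t *testing.T)  { /* invalid input, expect a clean error */ }
    func TestReadMetadata_Ugly(t *testing.T) { /* hostile or edge-case input, expect no panic */ }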
Claude
41b34b6779
feat(ax): apply RFC-025 AX compliance review
Principle 1 — Predictable Names:
- rocmModel.srv → rocmModel.server (struct field)
- recordMetrics: met → metrics (local var)
- backend.go/model.go: cfg → config (local vars)
- gguf.go: tc/kc → tensorCount32/kvCount32 (v2 count reads)

Principle 2 — Comments as Usage Examples:
- Added concrete usage examples to all exported functions:
  VRAMInfo, ModelInfo, DiscoverModels, GetVRAMInfo,
  ROCmAvailable, LoadModel, Available, NewClient, Health,
  ChatComplete, Complete, ReadMetadata, FileTypeName

Principle 5 — Test naming (_Good/_Bad/_Ugly):
- All test functions renamed to the AX-7 convention across:
  discover_test.go, vram_test.go, server_test.go,
  internal/gguf/gguf_test.go, internal/llamacpp/client_test.go,
  internal/llamacpp/health_test.go

Also: fix go.sum missing entry for dappco.re/go/core transitive dep
(pulled in by go-inference replace directive).

All tests pass: go test ./... -short -count=1

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-31 07:33:47 +01:00
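Principle 2 in practice: every exported symbol gains a usage line in its doc comment, as in this excerpt from the internal/gguf hunk below:

    // FileTypeName returns a human-readable name for a GGML quantisation file type.
    // Unknown types return "type_N" where N is the numeric value.
    //
    // name := gguf.FileTypeName(15) // "Q4_K_M"
    func FileTypeName(ft uint32) string { ... }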
20 changed files with 824 additions and 495 deletions


@@ -3,8 +3,7 @@
package rocm
import (
"os"
"strings"
"dappco.re/go/core"
coreerr "forge.lthn.ai/core/go-log"
"forge.lthn.ai/core/go-inference"
@@ -18,8 +17,11 @@ func (b *rocmBackend) Name() string { return "rocm" }
// Available reports whether ROCm GPU inference can run on this machine.
// Checks for the ROCm kernel driver (/dev/kfd) and a findable llama-server binary.
//
// b := inference.FindBackend("rocm")
// if b.Available() { /* safe to LoadModel */ }
func (b *rocmBackend) Available() bool {
if _, err := os.Stat("/dev/kfd"); err != nil {
if !(&core.Fs{}).New("/").Exists("/dev/kfd") {
return false
}
if _, err := findLlamaServer(); err != nil {
@@ -32,8 +34,14 @@ func (b *rocmBackend) Available() bool {
// Model architecture is read from GGUF metadata (replacing filename-based guessing).
// If no context length is specified, defaults to min(model_context_length, 4096)
// to prevent VRAM exhaustion on models with 128K+ native context.
//
// m, err := backend.LoadModel("/data/lem/gguf/model.gguf",
// inference.WithContextLen(4096),
// inference.WithGPULayers(-1),
// )
// defer m.Close()
func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (inference.TextModel, error) {
cfg := inference.ApplyLoadOpts(opts)
configuration := inference.ApplyLoadOpts(opts)
binary, err := findLlamaServer()
if err != nil {
@@ -45,12 +53,12 @@ func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (infe
return nil, coreerr.E("rocm.LoadModel", "read model metadata", err)
}
ctxLen := cfg.ContextLen
if ctxLen == 0 && meta.ContextLength > 0 {
ctxLen = int(min(meta.ContextLength, 4096))
contextLen := configuration.ContextLen
if contextLen == 0 && meta.ContextLength > 0 {
contextLen = int(min(meta.ContextLength, 4096))
}
srv, err := startServer(binary, path, cfg.GPULayers, ctxLen, cfg.ParallelSlots)
subprocess, err := startServer(binary, path, configuration.GPULayers, contextLen, configuration.ParallelSlots)
if err != nil {
return nil, err
}
@@ -58,34 +66,34 @@ func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (infe
// Map quantisation file type to bit width.
quantBits := 0
quantGroup := 0
ftName := gguf.FileTypeName(meta.FileType)
fileTypeName := gguf.FileTypeName(meta.FileType)
switch {
case strings.HasPrefix(ftName, "Q4_"):
case core.HasPrefix(fileTypeName, "Q4_"):
quantBits = 4
quantGroup = 32
case strings.HasPrefix(ftName, "Q5_"):
case core.HasPrefix(fileTypeName, "Q5_"):
quantBits = 5
quantGroup = 32
case strings.HasPrefix(ftName, "Q8_"):
case core.HasPrefix(fileTypeName, "Q8_"):
quantBits = 8
quantGroup = 32
case strings.HasPrefix(ftName, "Q2_"):
case core.HasPrefix(fileTypeName, "Q2_"):
quantBits = 2
quantGroup = 16
case strings.HasPrefix(ftName, "Q3_"):
case core.HasPrefix(fileTypeName, "Q3_"):
quantBits = 3
quantGroup = 32
case strings.HasPrefix(ftName, "Q6_"):
case core.HasPrefix(fileTypeName, "Q6_"):
quantBits = 6
quantGroup = 64
case ftName == "F16":
case fileTypeName == "F16":
quantBits = 16
case ftName == "F32":
case fileTypeName == "F32":
quantBits = 32
}
return &rocmModel{
srv: srv,
server: subprocess,
modelType: meta.Architecture,
modelInfo: inference.ModelInfo{
Architecture: meta.Architecture,


@@ -1,18 +1,18 @@
package rocm
import (
"path/filepath"
"dappco.re/go/core"
"forge.lthn.ai/core/go-rocm/internal/gguf"
)
// DiscoverModels scans a directory for GGUF model files and returns
// structured information about each. Files that cannot be parsed are skipped.
//
// models, err := rocm.DiscoverModels("/data/lem/gguf")
// for _, model := range models { core.Print(c, "%s %s ctx=%d", model.Name, model.Quantisation, model.ContextLen) }
func DiscoverModels(dir string) ([]ModelInfo, error) {
matches, err := filepath.Glob(filepath.Join(dir, "*.gguf"))
if err != nil {
return nil, err
}
matches := core.PathGlob(core.Path(dir, "*.gguf"))
var models []ModelInfo
for _, path := range matches {


@@ -3,9 +3,10 @@ package rocm
import (
"encoding/binary"
"os"
"path/filepath"
"testing"
"dappco.re/go/core"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -15,7 +16,7 @@ import (
func writeDiscoverTestGGUF(t *testing.T, dir, filename string, kvs [][2]any) string {
t.Helper()
path := filepath.Join(dir, filename)
path := core.Path(dir, filename)
f, err := os.Create(path)
require.NoError(t, err)
@@ -63,7 +64,7 @@ func writeDiscoverKV(t *testing.T, f *os.File, key string, val any) {
}
}
func TestDiscoverModels(t *testing.T) {
func TestDiscoverModels_Good(t *testing.T) {
dir := t.TempDir()
// Create two valid GGUF model files.
@@ -86,7 +87,7 @@ func TestDiscoverModels(t *testing.T) {
})
// Create a non-GGUF file that should be ignored (no .gguf extension).
require.NoError(t, os.WriteFile(filepath.Join(dir, "README.txt"), []byte("not a model"), 0644))
require.NoError(t, os.WriteFile(core.Path(dir, "README.txt"), []byte("not a model"), 0644))
models, err := DiscoverModels(dir)
require.NoError(t, err)
@@ -94,7 +95,7 @@ func TestDiscoverModels(t *testing.T) {
// Sort order from Glob is lexicographic, so gemma3 comes first.
gemma := models[0]
assert.Equal(t, filepath.Join(dir, "gemma3-4b-q4km.gguf"), gemma.Path)
assert.Equal(t, core.Path(dir, "gemma3-4b-q4km.gguf"), gemma.Path)
assert.Equal(t, "gemma3", gemma.Architecture)
assert.Equal(t, "Gemma 3 4B Instruct", gemma.Name)
assert.Equal(t, "Q4_K_M", gemma.Quantisation)
@@ -103,7 +104,7 @@ func TestDiscoverModels(t *testing.T) {
assert.Greater(t, gemma.FileSize, int64(0))
llama := models[1]
assert.Equal(t, filepath.Join(dir, "llama-3.1-8b-q4km.gguf"), llama.Path)
assert.Equal(t, core.Path(dir, "llama-3.1-8b-q4km.gguf"), llama.Path)
assert.Equal(t, "llama", llama.Architecture)
assert.Equal(t, "Llama 3.1 8B Instruct", llama.Name)
assert.Equal(t, "Q4_K_M", llama.Quantisation)
@@ -112,7 +113,7 @@ func TestDiscoverModels(t *testing.T) {
assert.Greater(t, llama.FileSize, int64(0))
}
func TestDiscoverModels_EmptyDir(t *testing.T) {
func TestDiscoverModels_Good_EmptyDir(t *testing.T) {
dir := t.TempDir()
models, err := DiscoverModels(dir)
@@ -120,15 +121,15 @@ func TestDiscoverModels_EmptyDir(t *testing.T) {
assert.Empty(t, models)
}
func TestDiscoverModels_NotFound(t *testing.T) {
// filepath.Glob returns nil, nil for a pattern matching no files,
func TestDiscoverModels_Bad_NonExistentDir(t *testing.T) {
// core.PathGlob returns nil for patterns matching no files,
// even when the directory does not exist.
models, err := DiscoverModels("/nonexistent/dir")
require.NoError(t, err)
assert.Empty(t, models)
}
func TestDiscoverModels_SkipsCorruptFile(t *testing.T) {
func TestDiscoverModels_Ugly_SkipsCorruptFile(t *testing.T) {
dir := t.TempDir()
// Create a valid GGUF file.
@@ -139,7 +140,7 @@ func TestDiscoverModels_SkipsCorruptFile(t *testing.T) {
})
// Create a corrupt .gguf file (not valid GGUF binary).
require.NoError(t, os.WriteFile(filepath.Join(dir, "corrupt.gguf"), []byte("not gguf data"), 0644))
require.NoError(t, os.WriteFile(core.Path(dir, "corrupt.gguf"), []byte("not gguf data"), 0644))
models, err := DiscoverModels(dir)
require.NoError(t, err)

go.mod

@@ -1,26 +1,18 @@
module dappco.re/go/core/rocm
module forge.lthn.ai/core/go-rocm
go 1.26.0
require (
dappco.re/go/core/inference v0.1.5
dappco.re/go/core/log v0.0.4
dappco.re/go/core v0.8.0-alpha.1
forge.lthn.ai/core/go-inference v0.1.5
forge.lthn.ai/core/go-log v0.0.4
)
require github.com/kr/text v0.2.0 // indirect
require (
dappco.re/go/core v0.5.0
dappco.re/go/core/api v0.2.0
dappco.re/go/core/i18n v0.2.0
dappco.re/go/core/io v0.2.0
dappco.re/go/core/log v0.1.0
dappco.re/go/core/process v0.3.0
dappco.re/go/core/scm v0.4.0
dappco.re/go/core/store v0.2.0
dappco.re/go/core/ws v0.3.0
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/stretchr/testify v1.11.1
gopkg.in/yaml.v3 v3.0.1 // indirect
)
replace forge.lthn.ai/core/go-inference => ../go-inference
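The go.sum fix noted in commit 41b34b6779 follows from this replace directive: go-inference resolves to the local ../go-inference checkout, but its requirement on dappco.re/go/core still needs entries in this module's go.sum. The usual repair, assuming a standard Go toolchain:

    go mod tidy   # re-resolves requirements and rewrites go.sum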

go.sum

@@ -1,6 +1,7 @@
dappco.re/go/core v0.8.0-alpha.1 h1:gj7+Scv+L63Z7wMxbJYHhaRFkHJo2u4MMPuUSv/Dhtk=
dappco.re/go/core v0.8.0-alpha.1/go.mod h1:f2/tBZ3+3IqDrg2F5F598llv0nmb/4gJVCFzM5geE4A=
forge.lthn.ai/core/go-log v0.0.4 h1:KTuCEPgFmuM8KJfnyQ8vPOU1Jg654W74h8IJvfQMfv0=
forge.lthn.ai/core/go-log v0.0.4/go.mod h1:r14MXKOD3LF/sI8XUJQhRk/SZHBE7jAFVuCfgkXoZPw=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=


@@ -10,11 +10,11 @@ package gguf
import (
"bufio"
"encoding/binary"
"fmt"
"io"
"math"
"os"
"strings"
"dappco.re/go/core"
coreerr "forge.lthn.ai/core/go-log"
)
@@ -72,66 +72,73 @@ var fileTypeNames = map[uint32]string{
// FileTypeName returns a human-readable name for a GGML quantisation file type.
// Unknown types return "type_N" where N is the numeric value.
//
// name := gguf.FileTypeName(15) // "Q4_K_M"
// name := gguf.FileTypeName(17) // "Q5_K_M"
func FileTypeName(ft uint32) string {
if name, ok := fileTypeNames[ft]; ok {
return name
}
return fmt.Sprintf("type_%d", ft)
return core.Sprintf("type_%d", ft)
}
// ReadMetadata reads the GGUF header from the file at path and returns the
// extracted metadata. Only metadata KV pairs are read; tensor data is not loaded.
//
// meta, err := gguf.ReadMetadata("/data/lem/gguf/gemma3-4b-q4km.gguf")
// // meta.Architecture == "gemma3", meta.ContextLength == 32768
func ReadMetadata(path string) (Metadata, error) {
f, err := os.Open(path)
if err != nil {
return Metadata{}, err
openResult := (&core.Fs{}).New("/").Open(path)
if !openResult.OK {
return Metadata{}, openResult.Value.(error)
}
defer f.Close()
file := openResult.Value.(*os.File)
defer file.Close()
info, err := f.Stat()
info, err := file.Stat()
if err != nil {
return Metadata{}, err
}
r := bufio.NewReader(f)
reader := bufio.NewReader(file)
// Read and validate magic number.
var magic uint32
if err := binary.Read(r, binary.LittleEndian, &magic); err != nil {
if err := binary.Read(reader, binary.LittleEndian, &magic); err != nil {
return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading magic", err)
}
if magic != ggufMagic {
return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("invalid magic: 0x%08X (expected 0x%08X)", magic, ggufMagic), nil)
return Metadata{}, coreerr.E("gguf.ReadMetadata", core.Sprintf("invalid magic: 0x%08X (expected 0x%08X)", magic, ggufMagic), nil)
}
// Read version.
var version uint32
if err := binary.Read(r, binary.LittleEndian, &version); err != nil {
if err := binary.Read(reader, binary.LittleEndian, &version); err != nil {
return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading version", err)
}
if version < 2 || version > 3 {
return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("unsupported GGUF version: %d", version), nil)
return Metadata{}, coreerr.E("gguf.ReadMetadata", core.Sprintf("unsupported GGUF version: %d", version), nil)
}
// Read tensor count and KV count. v3 uses uint64, v2 uses uint32.
var tensorCount, kvCount uint64
if version == 3 {
if err := binary.Read(r, binary.LittleEndian, &tensorCount); err != nil {
if err := binary.Read(reader, binary.LittleEndian, &tensorCount); err != nil {
return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading tensor count", err)
}
if err := binary.Read(r, binary.LittleEndian, &kvCount); err != nil {
if err := binary.Read(reader, binary.LittleEndian, &kvCount); err != nil {
return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading kv count", err)
}
} else {
var tc, kc uint32
if err := binary.Read(r, binary.LittleEndian, &tc); err != nil {
var tensorCount32, kvCount32 uint32
if err := binary.Read(reader, binary.LittleEndian, &tensorCount32); err != nil {
return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading tensor count", err)
}
if err := binary.Read(r, binary.LittleEndian, &kc); err != nil {
if err := binary.Read(reader, binary.LittleEndian, &kvCount32); err != nil {
return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading kv count", err)
}
tensorCount = uint64(tc)
kvCount = uint64(kc)
tensorCount = uint64(tensorCount32)
kvCount = uint64(kvCount32)
}
_ = tensorCount // we only read metadata KVs
@@ -148,76 +155,76 @@ func ReadMetadata(path string) (Metadata, error) {
candidateBlockCount := make(map[string]uint32)
for i := uint64(0); i < kvCount; i++ {
key, err := readString(r)
key, err := readString(reader)
if err != nil {
return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading key %d", i), err)
return Metadata{}, coreerr.E("gguf.ReadMetadata", core.Sprintf("reading key %d", i), err)
}
var valType uint32
if err := binary.Read(r, binary.LittleEndian, &valType); err != nil {
return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value type for key %q", key), err)
if err := binary.Read(reader, binary.LittleEndian, &valType); err != nil {
return Metadata{}, coreerr.E("gguf.ReadMetadata", core.Sprintf("reading value type for key %q", key), err)
}
// Check whether this is an interesting key before reading the value.
switch {
case key == "general.architecture":
v, err := readTypedValue(r, valType)
rawValue, err := readTypedValue(reader, valType)
if err != nil {
return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err)
return Metadata{}, coreerr.E("gguf.ReadMetadata", core.Sprintf("reading value for key %q", key), err)
}
if s, ok := v.(string); ok {
meta.Architecture = s
if stringValue, ok := rawValue.(string); ok {
meta.Architecture = stringValue
}
case key == "general.name":
v, err := readTypedValue(r, valType)
rawValue, err := readTypedValue(reader, valType)
if err != nil {
return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err)
return Metadata{}, coreerr.E("gguf.ReadMetadata", core.Sprintf("reading value for key %q", key), err)
}
if s, ok := v.(string); ok {
meta.Name = s
if stringValue, ok := rawValue.(string); ok {
meta.Name = stringValue
}
case key == "general.file_type":
v, err := readTypedValue(r, valType)
rawValue, err := readTypedValue(reader, valType)
if err != nil {
return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err)
return Metadata{}, coreerr.E("gguf.ReadMetadata", core.Sprintf("reading value for key %q", key), err)
}
if u, ok := v.(uint32); ok {
meta.FileType = u
if uint32Value, ok := rawValue.(uint32); ok {
meta.FileType = uint32Value
}
case key == "general.size_label":
v, err := readTypedValue(r, valType)
rawValue, err := readTypedValue(reader, valType)
if err != nil {
return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err)
return Metadata{}, coreerr.E("gguf.ReadMetadata", core.Sprintf("reading value for key %q", key), err)
}
if s, ok := v.(string); ok {
meta.SizeLabel = s
if stringValue, ok := rawValue.(string); ok {
meta.SizeLabel = stringValue
}
case strings.HasSuffix(key, ".context_length"):
v, err := readTypedValue(r, valType)
case core.HasSuffix(key, ".context_length"):
rawValue, err := readTypedValue(reader, valType)
if err != nil {
return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err)
return Metadata{}, coreerr.E("gguf.ReadMetadata", core.Sprintf("reading value for key %q", key), err)
}
if u, ok := v.(uint32); ok {
candidateContextLength[key] = u
if uint32Value, ok := rawValue.(uint32); ok {
candidateContextLength[key] = uint32Value
}
case strings.HasSuffix(key, ".block_count"):
v, err := readTypedValue(r, valType)
case core.HasSuffix(key, ".block_count"):
rawValue, err := readTypedValue(reader, valType)
if err != nil {
return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err)
return Metadata{}, coreerr.E("gguf.ReadMetadata", core.Sprintf("reading value for key %q", key), err)
}
if u, ok := v.(uint32); ok {
candidateBlockCount[key] = u
if uint32Value, ok := rawValue.(uint32); ok {
candidateBlockCount[key] = uint32Value
}
default:
// Skip uninteresting value.
if err := skipValue(r, valType); err != nil {
return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("skipping value for key %q", key), err)
if err := skipValue(reader, valType); err != nil {
return Metadata{}, coreerr.E("gguf.ReadMetadata", core.Sprintf("skipping value for key %q", key), err)
}
}
}
@@ -225,11 +232,11 @@ func ReadMetadata(path string) (Metadata, error) {
// Resolve architecture-specific keys.
if meta.Architecture != "" {
prefix := meta.Architecture + "."
if v, ok := candidateContextLength[prefix+"context_length"]; ok {
meta.ContextLength = v
if contextLength, ok := candidateContextLength[prefix+"context_length"]; ok {
meta.ContextLength = contextLength
}
if v, ok := candidateBlockCount[prefix+"block_count"]; ok {
meta.BlockCount = v
if blockCount, ok := candidateBlockCount[prefix+"block_count"]; ok {
meta.BlockCount = blockCount
}
}
@@ -241,94 +248,94 @@ func ReadMetadata(path string) (Metadata, error) {
const maxStringLength = 1 << 20
// readString reads a GGUF string: uint64 length followed by that many bytes.
func readString(r io.Reader) (string, error) {
func readString(reader io.Reader) (string, error) {
var length uint64
if err := binary.Read(r, binary.LittleEndian, &length); err != nil {
if err := binary.Read(reader, binary.LittleEndian, &length); err != nil {
return "", err
}
if length > maxStringLength {
return "", coreerr.E("gguf.readString", fmt.Sprintf("string length %d exceeds maximum %d", length, maxStringLength), nil)
return "", coreerr.E("gguf.readString", core.Sprintf("string length %d exceeds maximum %d", length, maxStringLength), nil)
}
buf := make([]byte, length)
if _, err := io.ReadFull(r, buf); err != nil {
buffer := make([]byte, length)
if _, err := io.ReadFull(reader, buffer); err != nil {
return "", err
}
return string(buf), nil
return string(buffer), nil
}
// readTypedValue reads a value of the given GGUF type and returns it as a Go
// value. String, uint32, and uint64 types return typed values (uint64 is
// downcast to uint32 when it fits). All others are read and discarded.
func readTypedValue(r io.Reader, valType uint32) (any, error) {
func readTypedValue(reader io.Reader, valType uint32) (any, error) {
switch valType {
case typeString:
return readString(r)
return readString(reader)
case typeUint32:
var v uint32
err := binary.Read(r, binary.LittleEndian, &v)
return v, err
var uint32Value uint32
err := binary.Read(reader, binary.LittleEndian, &uint32Value)
return uint32Value, err
case typeUint64:
var v uint64
if err := binary.Read(r, binary.LittleEndian, &v); err != nil {
var uint64Value uint64
if err := binary.Read(reader, binary.LittleEndian, &uint64Value); err != nil {
return nil, err
}
if v <= math.MaxUint32 {
return uint32(v), nil
if uint64Value <= math.MaxUint32 {
return uint32(uint64Value), nil
}
return v, nil
return uint64Value, nil
default:
// Read and discard the value, returning nil.
err := skipValue(r, valType)
err := skipValue(reader, valType)
return nil, err
}
}
// skipValue reads and discards a GGUF value of the given type from r.
func skipValue(r io.Reader, valType uint32) error {
// skipValue reads and discards a GGUF value of the given type from reader.
func skipValue(reader io.Reader, valType uint32) error {
switch valType {
case typeUint8, typeInt8, typeBool:
_, err := readN(r, 1)
_, err := readN(reader, 1)
return err
case typeUint16, typeInt16:
_, err := readN(r, 2)
_, err := readN(reader, 2)
return err
case typeUint32, typeInt32, typeFloat32:
_, err := readN(r, 4)
_, err := readN(reader, 4)
return err
case typeUint64, typeInt64, typeFloat64:
_, err := readN(r, 8)
_, err := readN(reader, 8)
return err
case typeString:
var length uint64
if err := binary.Read(r, binary.LittleEndian, &length); err != nil {
if err := binary.Read(reader, binary.LittleEndian, &length); err != nil {
return err
}
if length > maxStringLength {
return coreerr.E("gguf.skipValue", fmt.Sprintf("string length %d exceeds maximum %d", length, maxStringLength), nil)
return coreerr.E("gguf.skipValue", core.Sprintf("string length %d exceeds maximum %d", length, maxStringLength), nil)
}
_, err := readN(r, int64(length))
_, err := readN(reader, int64(length))
return err
case typeArray:
var elemType uint32
if err := binary.Read(r, binary.LittleEndian, &elemType); err != nil {
if err := binary.Read(reader, binary.LittleEndian, &elemType); err != nil {
return err
}
var count uint64
if err := binary.Read(r, binary.LittleEndian, &count); err != nil {
if err := binary.Read(reader, binary.LittleEndian, &count); err != nil {
return err
}
for i := uint64(0); i < count; i++ {
if err := skipValue(r, elemType); err != nil {
if err := skipValue(reader, elemType); err != nil {
return err
}
}
return nil
default:
return coreerr.E("gguf.skipValue", fmt.Sprintf("unknown GGUF value type: %d", valType), nil)
return coreerr.E("gguf.skipValue", core.Sprintf("unknown GGUF value type: %d", valType), nil)
}
}
// readN reads and discards exactly n bytes from r.
func readN(r io.Reader, n int64) (int64, error) {
return io.CopyN(io.Discard, r, n)
// readN reads and discards exactly n bytes from reader.
func readN(reader io.Reader, n int64) (int64, error) {
return io.CopyN(io.Discard, reader, n)
}
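For orientation, this is the header layout ReadMetadata walks, as implied by the reads above (all fields little-endian; v2 differs from v3 only in the count widths):

    magic      uint32   // "GGUF"
    version    uint32   // 2 or 3 accepted here
    tensors    uint64   // uint32 in v2; read but unused
    kvCount    uint64   // uint32 in v2
    kvCount x  { key string, valType uint32, value }
    string  =  uint64 length + raw bytes (capped at 1 MiB by maxStringLength)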


@@ -3,9 +3,10 @@ package gguf
import (
"encoding/binary"
"os"
"path/filepath"
"testing"
"dappco.re/go/core"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -16,7 +17,7 @@ func writeTestGGUFOrdered(t *testing.T, kvs [][2]any) string {
t.Helper()
dir := t.TempDir()
path := filepath.Join(dir, "test.gguf")
path := core.Path(dir, "test.gguf")
f, err := os.Create(path)
require.NoError(t, err)
@@ -86,7 +87,7 @@ func writeTestGGUFV2(t *testing.T, kvs [][2]any) string {
t.Helper()
dir := t.TempDir()
path := filepath.Join(dir, "test_v2.gguf")
path := core.Path(dir, "test_v2.gguf")
f, err := os.Create(path)
require.NoError(t, err)
@@ -109,7 +110,7 @@ func writeTestGGUFV2(t *testing.T, kvs [][2]any) string {
return path
}
func TestReadMetadata_Gemma3(t *testing.T) {
func TestReadMetadata_Good_Gemma3(t *testing.T) {
path := writeTestGGUFOrdered(t, [][2]any{
{"general.architecture", "gemma3"},
{"general.name", "Test Gemma3 1B"},
@@ -131,7 +132,7 @@ func TestReadMetadata_Gemma3(t *testing.T) {
assert.Greater(t, m.FileSize, int64(0))
}
func TestReadMetadata_Llama(t *testing.T) {
func TestReadMetadata_Good_Llama(t *testing.T) {
path := writeTestGGUFOrdered(t, [][2]any{
{"general.architecture", "llama"},
{"general.name", "Test Llama 8B"},
@@ -153,7 +154,7 @@ func TestReadMetadata_Llama(t *testing.T) {
assert.Greater(t, m.FileSize, int64(0))
}
func TestReadMetadata_ArchAfterContextLength(t *testing.T) {
func TestReadMetadata_Ugly_ArchAfterContextLength(t *testing.T) {
// Architecture key comes AFTER the arch-specific keys.
// The parser must handle deferred resolution of arch-prefixed keys.
path := writeTestGGUFOrdered(t, [][2]any{
@@ -174,9 +175,9 @@
assert.Equal(t, uint32(32), m.BlockCount)
}
func TestReadMetadata_InvalidMagic(t *testing.T) {
func TestReadMetadata_Bad_InvalidMagic(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "notgguf.bin")
path := core.Path(dir, "notgguf.bin")
err := os.WriteFile(path, []byte("this is not a GGUF file at all"), 0644)
require.NoError(t, err)
@@ -186,20 +187,30 @@
assert.Contains(t, err.Error(), "invalid magic")
}
func TestReadMetadata_FileNotFound(t *testing.T) {
func TestReadMetadata_Bad_FileNotFound(t *testing.T) {
_, err := ReadMetadata("/nonexistent/path/model.gguf")
require.Error(t, err)
}
func TestFileTypeName(t *testing.T) {
func TestFileTypeName_Good(t *testing.T) {
assert.Equal(t, "Q4_K_M", FileTypeName(15))
assert.Equal(t, "Q5_K_M", FileTypeName(17))
assert.Equal(t, "Q8_0", FileTypeName(7))
assert.Equal(t, "F16", FileTypeName(1))
assert.Equal(t, "type_999", FileTypeName(999))
}
func TestReadMetadata_V2(t *testing.T) {
func TestFileTypeName_Bad_UnknownType(t *testing.T) {
// Unknown type numbers must not panic; return "type_N".
name := FileTypeName(999)
assert.Equal(t, "type_999", name)
}
func TestFileTypeName_Ugly_ZeroType(t *testing.T) {
// Type 0 is F32 — the zero value must map to a valid name, not "type_0".
assert.Equal(t, "F32", FileTypeName(0))
}
func TestReadMetadata_Good_V2(t *testing.T) {
// GGUF v2 uses uint32 for tensor and KV counts (instead of uint64 in v3).
path := writeTestGGUFV2(t, [][2]any{
{"general.architecture", "llama"},
@@ -219,9 +230,9 @@
assert.Equal(t, uint32(16), m.BlockCount)
}
func TestReadMetadata_UnsupportedVersion(t *testing.T) {
func TestReadMetadata_Bad_UnsupportedVersion(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "bad_version.gguf")
path := core.Path(dir, "bad_version.gguf")
f, err := os.Create(path)
require.NoError(t, err)
@@ -235,11 +246,11 @@
assert.Contains(t, err.Error(), "unsupported GGUF version")
}
func TestReadMetadata_SkipsUnknownValueTypes(t *testing.T) {
func TestReadMetadata_Ugly_SkipsUnknownValueTypes(t *testing.T) {
// Tests skipValue for uint8, int16, float32, uint64, bool, and array types.
// These are stored under uninteresting keys so ReadMetadata skips them.
dir := t.TempDir()
path := filepath.Join(dir, "skip_types.gguf")
path := core.Path(dir, "skip_types.gguf")
f, err := os.Create(path)
require.NoError(t, err)
@@ -283,7 +294,7 @@ func TestReadMetadata_SkipsUnknownValueTypes(t *testing.T) {
b8 := make([]byte, 8)
binary.LittleEndian.PutUint64(b8, 3) // count: 3
arrBuf = append(arrBuf, b8...)
arrBuf = append(arrBuf, 10, 20, 30) // 3 uint8 values
arrBuf = append(arrBuf, 10, 20, 30) // 3 uint8 values
writeRawKV(t, f, "custom.array_val", 9, arrBuf)
// 7-8. Interesting keys to verify parsing continued correctly.
@@ -299,7 +310,7 @@ func TestReadMetadata_SkipsUnknownValueTypes(t *testing.T) {
assert.Equal(t, "Skip Test Model", m.Name)
}
func TestReadMetadata_Uint64ContextLength(t *testing.T) {
func TestReadMetadata_Ugly_Uint64ContextLength(t *testing.T) {
// context_length stored as uint64 that fits in uint32 — readTypedValue
// should downcast it to uint32.
path := writeTestGGUFOrdered(t, [][2]any{
@@ -315,9 +326,9 @@
assert.Equal(t, uint32(32), m.BlockCount)
}
func TestReadMetadata_TruncatedFile(t *testing.T) {
func TestReadMetadata_Bad_TruncatedFile(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "truncated.gguf")
path := core.Path(dir, "truncated.gguf")
// Write only the magic — no version or counts.
f, err := os.Create(path)
@@ -330,10 +341,10 @@ func TestReadMetadata_TruncatedFile(t *testing.T) {
assert.Contains(t, err.Error(), "reading version")
}
func TestReadMetadata_SkipsStringValue(t *testing.T) {
func TestReadMetadata_Ugly_SkipsStringValue(t *testing.T) {
// Tests skipValue for string type (type 8) on an uninteresting key.
dir := t.TempDir()
path := filepath.Join(dir, "skip_string.gguf")
path := core.Path(dir, "skip_string.gguf")
f, err := os.Create(path)
require.NoError(t, err)


@@ -4,14 +4,13 @@ import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"iter"
"net/http"
"strings"
"sync"
"dappco.re/go/core"
coreerr "forge.lthn.ai/core/go-log"
)
@@ -62,47 +61,54 @@ type completionChunkResponse struct {
// ChatComplete sends a streaming chat completion request to /v1/chat/completions.
// It returns an iterator over text chunks and a function that returns any error
// that occurred during the request or while reading the stream.
func (c *Client) ChatComplete(ctx context.Context, req ChatRequest) (iter.Seq[string], func() error) {
req.Stream = true
//
// chunks, errorFunc := client.ChatComplete(ctx, ChatRequest{Messages: msgs, Temperature: 0.7})
// for text := range chunks { output += text }
// if err := errorFunc(); err != nil { /* handle */ }
func (c *Client) ChatComplete(ctx context.Context, chatRequest ChatRequest) (iter.Seq[string], func() error) {
chatRequest.Stream = true
body, err := json.Marshal(req)
if err != nil {
return noChunks, func() error { return coreerr.E("llamacpp.ChatComplete", "marshal chat request", err) }
marshalResult := core.JSONMarshal(chatRequest)
if !marshalResult.OK {
marshalErr := marshalResult.Value.(error)
return noChunks, func() error { return coreerr.E("llamacpp.ChatComplete", "marshal chat request", marshalErr) }
}
body := marshalResult.Value.([]byte)
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/chat/completions", bytes.NewReader(body))
httpRequest, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/chat/completions", bytes.NewReader(body))
if err != nil {
return noChunks, func() error { return coreerr.E("llamacpp.ChatComplete", "create chat request", err) }
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Accept", "text/event-stream")
httpRequest.Header.Set("Content-Type", "application/json")
httpRequest.Header.Set("Accept", "text/event-stream")
resp, err := c.httpClient.Do(httpReq)
response, err := c.httpClient.Do(httpRequest)
if err != nil {
return noChunks, func() error { return coreerr.E("llamacpp.ChatComplete", "chat request", err) }
}
if resp.StatusCode != http.StatusOK {
defer resp.Body.Close()
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 256))
if response.StatusCode != http.StatusOK {
defer response.Body.Close()
responseBody, _ := io.ReadAll(io.LimitReader(response.Body, 256))
return noChunks, func() error {
return coreerr.E("llamacpp.ChatComplete", fmt.Sprintf("chat returned %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody))), nil)
return coreerr.E("llamacpp.ChatComplete", core.Sprintf("chat returned %d: %s", response.StatusCode, core.Trim(string(responseBody))), nil)
}
}
var (
streamErr error
closeOnce sync.Once
closeBody = func() { closeOnce.Do(func() { resp.Body.Close() }) }
closeBody = func() { closeOnce.Do(func() { response.Body.Close() }) }
)
sseData := parseSSE(resp.Body, &streamErr)
sseData := parseSSE(response.Body, &streamErr)
tokens := func(yield func(string) bool) {
defer closeBody()
for raw := range sseData {
for ssePayload := range sseData {
var chunk chatChunkResponse
if err := json.Unmarshal([]byte(raw), &chunk); err != nil {
streamErr = coreerr.E("llamacpp.ChatComplete", "decode chat chunk", err)
decodeResult := core.JSONUnmarshal([]byte(ssePayload), &chunk)
if !decodeResult.OK {
streamErr = coreerr.E("llamacpp.ChatComplete", "decode chat chunk", decodeResult.Value.(error))
return
}
if len(chunk.Choices) == 0 {
@@ -127,47 +133,54 @@ func (c *Client) ChatComplete(ctx context.Context, req ChatRequest) (iter.Seq[st
// Complete sends a streaming completion request to /v1/completions.
// It returns an iterator over text chunks and a function that returns any error
// that occurred during the request or while reading the stream.
func (c *Client) Complete(ctx context.Context, req CompletionRequest) (iter.Seq[string], func() error) {
req.Stream = true
//
// chunks, errorFunc := client.Complete(ctx, CompletionRequest{Prompt: "Once upon", Temperature: 0.8})
// for text := range chunks { output += text }
// if err := errorFunc(); err != nil { /* handle */ }
func (c *Client) Complete(ctx context.Context, completionRequest CompletionRequest) (iter.Seq[string], func() error) {
completionRequest.Stream = true
body, err := json.Marshal(req)
if err != nil {
return noChunks, func() error { return coreerr.E("llamacpp.Complete", "marshal completion request", err) }
marshalResult := core.JSONMarshal(completionRequest)
if !marshalResult.OK {
marshalErr := marshalResult.Value.(error)
return noChunks, func() error { return coreerr.E("llamacpp.Complete", "marshal completion request", marshalErr) }
}
body := marshalResult.Value.([]byte)
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/completions", bytes.NewReader(body))
httpRequest, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/completions", bytes.NewReader(body))
if err != nil {
return noChunks, func() error { return coreerr.E("llamacpp.Complete", "create completion request", err) }
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Accept", "text/event-stream")
httpRequest.Header.Set("Content-Type", "application/json")
httpRequest.Header.Set("Accept", "text/event-stream")
resp, err := c.httpClient.Do(httpReq)
response, err := c.httpClient.Do(httpRequest)
if err != nil {
return noChunks, func() error { return coreerr.E("llamacpp.Complete", "completion request", err) }
}
if resp.StatusCode != http.StatusOK {
defer resp.Body.Close()
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 256))
if response.StatusCode != http.StatusOK {
defer response.Body.Close()
responseBody, _ := io.ReadAll(io.LimitReader(response.Body, 256))
return noChunks, func() error {
return coreerr.E("llamacpp.Complete", fmt.Sprintf("completion returned %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody))), nil)
return coreerr.E("llamacpp.Complete", core.Sprintf("completion returned %d: %s", response.StatusCode, core.Trim(string(responseBody))), nil)
}
}
var (
streamErr error
closeOnce sync.Once
closeBody = func() { closeOnce.Do(func() { resp.Body.Close() }) }
closeBody = func() { closeOnce.Do(func() { response.Body.Close() }) }
)
sseData := parseSSE(resp.Body, &streamErr)
sseData := parseSSE(response.Body, &streamErr)
tokens := func(yield func(string) bool) {
defer closeBody()
for raw := range sseData {
for ssePayload := range sseData {
var chunk completionChunkResponse
if err := json.Unmarshal([]byte(raw), &chunk); err != nil {
streamErr = coreerr.E("llamacpp.Complete", "decode completion chunk", err)
decodeResult := core.JSONUnmarshal([]byte(ssePayload), &chunk)
if !decodeResult.OK {
streamErr = coreerr.E("llamacpp.Complete", "decode completion chunk", decodeResult.Value.(error))
return
}
if len(chunk.Choices) == 0 {
@@ -189,18 +202,18 @@ func (c *Client) Complete(ctx context.Context, req CompletionRequest) (iter.Seq[
}
}
// parseSSE reads SSE-formatted lines from r and yields the payload of each
// parseSSE reads SSE-formatted lines from reader and yields the payload of each
// "data: " line. It stops when it encounters "[DONE]" or an I/O error.
// Any read error (other than EOF) is stored via errOut.
func parseSSE(r io.Reader, errOut *error) iter.Seq[string] {
func parseSSE(reader io.Reader, errOut *error) iter.Seq[string] {
return func(yield func(string) bool) {
scanner := bufio.NewScanner(r)
scanner := bufio.NewScanner(reader)
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") {
if !core.HasPrefix(line, "data: ") {
continue
}
payload := strings.TrimPrefix(line, "data: ")
payload := core.TrimPrefix(line, "data: ")
if payload == "[DONE]" {
return
}
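Concretely, the stream parseSSE consumes looks like this (payloads adapted from the client tests below); each "data: " payload is JSON-decoded into a chunk struct, and "[DONE]" terminates the stream:

    data: {"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}

    data: {"choices":[{"delta":{"content":" world"},"finish_reason":"stop"}]}

    data: [DONE]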


@@ -2,7 +2,7 @@ package llamacpp
import (
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
@@ -22,12 +22,12 @@ func sseLines(w http.ResponseWriter, lines []string) {
w.WriteHeader(http.StatusOK)
for _, line := range lines {
fmt.Fprintf(w, "data: %s\n\n", line)
_, _ = io.WriteString(w, "data: "+line+"\n\n")
f.Flush()
}
}
func TestChatComplete_Streaming(t *testing.T) {
func TestChatComplete_Good_Streaming(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/v1/chat/completions", r.URL.Path)
assert.Equal(t, "POST", r.Method)
@@ -56,7 +56,7 @@ func TestChatComplete_Streaming(t *testing.T) {
assert.Equal(t, []string{"Hello", " world"}, got)
}
func TestChatComplete_EmptyResponse(t *testing.T) {
func TestChatComplete_Good_EmptyResponse(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sseLines(w, []string{"[DONE]"})
}))
@@ -77,7 +77,7 @@ func TestChatComplete_EmptyResponse(t *testing.T) {
assert.Empty(t, got)
}
func TestChatComplete_HTTPError(t *testing.T) {
func TestChatComplete_Bad_HTTPError(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "internal server error", http.StatusInternalServerError)
}))
@@ -100,7 +100,7 @@ func TestChatComplete_HTTPError(t *testing.T) {
assert.Contains(t, err.Error(), "500")
}
func TestChatComplete_ContextCancelled(t *testing.T) {
func TestChatComplete_Ugly_ContextCancelled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -114,7 +114,7 @@ func TestChatComplete_ContextCancelled(t *testing.T) {
w.WriteHeader(http.StatusOK)
// Send first chunk.
fmt.Fprintf(w, "data: %s\n\n", `{"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}`)
_, _ = io.WriteString(w, "data: "+`{"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}`+"\n\n")
f.Flush()
// Wait for context cancellation before sending more.
@@ -140,7 +140,7 @@ func TestChatComplete_ContextCancelled(t *testing.T) {
assert.Equal(t, []string{"Hello"}, got)
}
func TestComplete_Streaming(t *testing.T) {
func TestComplete_Good_Streaming(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/v1/completions", r.URL.Path)
assert.Equal(t, "POST", r.Method)
@@ -170,7 +170,7 @@ func TestComplete_Streaming(t *testing.T) {
assert.Equal(t, []string{"Once", " upon", " a time"}, got)
}
func TestComplete_HTTPError(t *testing.T) {
func TestComplete_Bad_HTTPError(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "bad request", http.StatusBadRequest)
}))
@@ -192,3 +192,43 @@
require.Error(t, err)
assert.Contains(t, err.Error(), "400")
}
func TestComplete_Ugly_ContextCancelled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
f, ok := w.(http.Flusher)
if !ok {
panic("ResponseWriter does not implement Flusher")
}
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.WriteHeader(http.StatusOK)
// Send first chunk.
_, _ = io.WriteString(w, "data: "+`{"choices":[{"text":"Once","finish_reason":null}]}`+"\n\n")
f.Flush()
// Wait for context cancellation before sending more.
<-r.Context().Done()
}))
defer ts.Close()
c := NewClient(ts.URL)
tokens, errFn := c.Complete(ctx, CompletionRequest{
Prompt: "Once",
Temperature: 0.7,
Stream: true,
})
var got []string
for tok := range tokens {
got = append(got, tok)
cancel() // Cancel after receiving the first token.
}
// The error may or may not be nil depending on timing;
// the important thing is we got exactly 1 token.
_ = errFn()
assert.Equal(t, []string{"Once"}, got)
}


@@ -2,11 +2,10 @@ package llamacpp
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"dappco.re/go/core"
coreerr "forge.lthn.ai/core/go-log"
)
@@ -18,9 +17,11 @@ type Client struct {
}
// NewClient creates a client for the llama-server at the given base URL.
//
// client := llamacpp.NewClient("http://127.0.0.1:8080")
func NewClient(baseURL string) *Client {
return &Client{
baseURL: strings.TrimRight(baseURL, "/"),
baseURL: core.TrimSuffix(baseURL, "/"),
httpClient: &http.Client{},
}
}
@@ -30,27 +31,36 @@ type healthResponse struct {
}
// Health checks whether the llama-server is ready to accept requests.
//
// if err := client.Health(ctx); err != nil { /* server not ready */ }
func (c *Client) Health(ctx context.Context) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+"/health", nil)
request, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+"/health", nil)
if err != nil {
return err
}
resp, err := c.httpClient.Do(req)
response, err := c.httpClient.Do(request)
if err != nil {
return err
}
defer resp.Body.Close()
defer response.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 256))
return coreerr.E("llamacpp.Health", fmt.Sprintf("health returned %d: %s", resp.StatusCode, string(body)), nil)
if response.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(response.Body, 256))
return coreerr.E("llamacpp.Health", core.Sprintf("health returned %d: %s", response.StatusCode, string(body)), nil)
}
var h healthResponse
if err := json.NewDecoder(resp.Body).Decode(&h); err != nil {
return coreerr.E("llamacpp.Health", "health decode", err)
var healthStatus healthResponse
decodeResult := core.JSONUnmarshal(mustReadAll(response.Body), &healthStatus)
if !decodeResult.OK {
return coreerr.E("llamacpp.Health", "health decode", decodeResult.Value.(error))
}
if h.Status != "ok" {
return coreerr.E("llamacpp.Health", fmt.Sprintf("server not ready (status: %s)", h.Status), nil)
if healthStatus.Status != "ok" {
return coreerr.E("llamacpp.Health", core.Sprintf("server not ready (status: %s)", healthStatus.Status), nil)
}
return nil
}
// mustReadAll reads all bytes from reader, returning nil on error.
func mustReadAll(reader io.Reader) []byte {
data, _ := io.ReadAll(reader)
return data
}


@@ -10,7 +10,7 @@ import (
"github.com/stretchr/testify/require"
)
func TestHealth_OK(t *testing.T) {
func TestHealth_Good(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/health", r.URL.Path)
w.Header().Set("Content-Type", "application/json")
@@ -23,7 +23,7 @@ func TestHealth_OK(t *testing.T) {
require.NoError(t, err)
}
func TestHealth_NotReady(t *testing.T) {
func TestHealth_Bad_NotReady(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"status":"loading model"}`))
@@ -35,7 +35,7 @@ func TestHealth_NotReady(t *testing.T) {
assert.ErrorContains(t, err, "not ready")
}
func TestHealth_Loading(t *testing.T) {
func TestHealth_Bad_Loading(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusServiceUnavailable)
@@ -48,8 +48,20 @@ func TestHealth_Loading(t *testing.T) {
assert.ErrorContains(t, err, "503")
}
func TestHealth_ServerDown(t *testing.T) {
func TestHealth_Bad_ServerDown(t *testing.T) {
c := NewClient("http://127.0.0.1:1") // nothing listening
err := c.Health(context.Background())
assert.Error(t, err)
}
func TestHealth_Ugly_MalformedJSON(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`not valid json`))
}))
defer ts.Close()
c := NewClient(ts.URL)
err := c.Health(context.Background())
assert.Error(t, err)
}

model.go

@@ -4,12 +4,12 @@ package rocm
import (
"context"
"fmt"
"iter"
"strings"
"sync"
"time"
"dappco.re/go/core"
coreerr "forge.lthn.ai/core/go-log"
"forge.lthn.ai/core/go-inference"
"forge.lthn.ai/core/go-rocm/internal/llamacpp"
@@ -17,7 +17,7 @@ import (
// rocmModel implements inference.TextModel using a llama-server subprocess.
type rocmModel struct {
srv *server
server *server
modelType string
modelInfo inference.ModelInfo
@@ -27,104 +27,116 @@ }
}
// Generate streams tokens for the given prompt via llama-server's /v1/completions endpoint.
//
// for tok := range m.Generate(ctx, "The capital of France is", inference.WithMaxTokens(32)) {
// output += tok.Text
// }
func (m *rocmModel) Generate(ctx context.Context, prompt string, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
m.mu.Lock()
m.lastErr = nil
m.mu.Unlock()
if !m.srv.alive() {
if !m.server.alive() {
m.setServerExitErr()
return func(yield func(inference.Token) bool) {}
}
cfg := inference.ApplyGenerateOpts(opts)
configuration := inference.ApplyGenerateOpts(opts)
req := llamacpp.CompletionRequest{
completionRequest := llamacpp.CompletionRequest{
Prompt: prompt,
MaxTokens: cfg.MaxTokens,
Temperature: cfg.Temperature,
TopK: cfg.TopK,
TopP: cfg.TopP,
RepeatPenalty: cfg.RepeatPenalty,
MaxTokens: configuration.MaxTokens,
Temperature: configuration.Temperature,
TopK: configuration.TopK,
TopP: configuration.TopP,
RepeatPenalty: configuration.RepeatPenalty,
}
start := time.Now()
chunks, errFn := m.srv.client.Complete(ctx, req)
chunks, errorFunc := m.server.client.Complete(ctx, completionRequest)
return func(yield func(inference.Token) bool) {
var count int
var tokenCount int
decodeStart := time.Now()
for text := range chunks {
count++
tokenCount++
if !yield(inference.Token{Text: text}) {
break
}
}
if err := errFn(); err != nil {
if err := errorFunc(); err != nil {
m.mu.Lock()
m.lastErr = err
m.mu.Unlock()
}
m.recordMetrics(0, count, start, decodeStart)
m.recordMetrics(0, tokenCount, start, decodeStart)
}
}
// Chat streams tokens from a multi-turn conversation via llama-server's /v1/chat/completions endpoint.
//
// msgs := []inference.Message{{Role: "user", Content: "Hello"}}
// for tok := range m.Chat(ctx, msgs, inference.WithMaxTokens(64)) {
// output += tok.Text
// }
func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
m.mu.Lock()
m.lastErr = nil
m.mu.Unlock()
if !m.srv.alive() {
if !m.server.alive() {
m.setServerExitErr()
return func(yield func(inference.Token) bool) {}
}
cfg := inference.ApplyGenerateOpts(opts)
configuration := inference.ApplyGenerateOpts(opts)
chatMsgs := make([]llamacpp.ChatMessage, len(messages))
chatMessages := make([]llamacpp.ChatMessage, len(messages))
for i, msg := range messages {
chatMsgs[i] = llamacpp.ChatMessage{
chatMessages[i] = llamacpp.ChatMessage{
Role: msg.Role,
Content: msg.Content,
}
}
req := llamacpp.ChatRequest{
Messages: chatMsgs,
MaxTokens: cfg.MaxTokens,
Temperature: cfg.Temperature,
TopK: cfg.TopK,
TopP: cfg.TopP,
RepeatPenalty: cfg.RepeatPenalty,
chatRequest := llamacpp.ChatRequest{
Messages: chatMessages,
MaxTokens: configuration.MaxTokens,
Temperature: configuration.Temperature,
TopK: configuration.TopK,
TopP: configuration.TopP,
RepeatPenalty: configuration.RepeatPenalty,
}
start := time.Now()
chunks, errFn := m.srv.client.ChatComplete(ctx, req)
chunks, errorFunc := m.server.client.ChatComplete(ctx, chatRequest)
return func(yield func(inference.Token) bool) {
var count int
var tokenCount int
decodeStart := time.Now()
for text := range chunks {
count++
tokenCount++
if !yield(inference.Token{Text: text}) {
break
}
}
if err := errFn(); err != nil {
if err := errorFunc(); err != nil {
m.mu.Lock()
m.lastErr = err
m.mu.Unlock()
}
m.recordMetrics(0, count, start, decodeStart)
m.recordMetrics(0, tokenCount, start, decodeStart)
}
}
// Classify runs batched prefill-only inference via llama-server.
// Each prompt gets a single-token completion (max_tokens=1, temperature=0).
// llama-server has no native classify endpoint, so this simulates it.
//
// results, err := m.Classify(ctx, []string{"positive review", "negative review"})
// // results[0].Token.Text == "pos" or similar top token
func (m *rocmModel) Classify(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
if !m.srv.alive() {
if !m.server.alive() {
m.setServerExitErr()
return nil, m.Err()
}
@@ -137,23 +149,23 @@ func (m *rocmModel) Classify(ctx context.Context, prompts []string, opts ...infe
return nil, ctx.Err()
}
req := llamacpp.CompletionRequest{
completionRequest := llamacpp.CompletionRequest{
Prompt: prompt,
MaxTokens: 1,
Temperature: 0,
}
chunks, errFn := m.srv.client.Complete(ctx, req)
var text strings.Builder
chunks, errorFunc := m.server.client.Complete(ctx, completionRequest)
builder := core.NewBuilder()
for chunk := range chunks {
text.WriteString(chunk)
builder.WriteString(chunk)
}
if err := errFn(); err != nil {
return nil, coreerr.E("rocm.Classify", fmt.Sprintf("classify prompt %d", i), err)
if err := errorFunc(); err != nil {
return nil, coreerr.E("rocm.Classify", core.Sprintf("classify prompt %d", i), err)
}
results[i] = inference.ClassifyResult{
Token: inference.Token{Text: text.String()},
Token: inference.Token{Text: builder.String()},
}
}
@@ -163,13 +175,16 @@ func (m *rocmModel) Classify(ctx context.Context, prompts []string, opts ...infe
// BatchGenerate runs batched autoregressive generation via llama-server.
// Each prompt is decoded sequentially up to MaxTokens.
//
// results, err := m.BatchGenerate(ctx, []string{"prompt A", "prompt B"}, inference.WithMaxTokens(64))
// // results[0].Tokens — tokens generated for prompt A
func (m *rocmModel) BatchGenerate(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.BatchResult, error) {
if !m.srv.alive() {
if !m.server.alive() {
m.setServerExitErr()
return nil, m.Err()
}
cfg := inference.ApplyGenerateOpts(opts)
configuration := inference.ApplyGenerateOpts(opts)
start := time.Now()
results := make([]inference.BatchResult, len(prompts))
var totalGenerated int
@@ -180,22 +195,22 @@ func (m *rocmModel) BatchGenerate(ctx context.Context, prompts []string, opts ..
continue
}
req := llamacpp.CompletionRequest{
completionRequest := llamacpp.CompletionRequest{
Prompt: prompt,
MaxTokens: cfg.MaxTokens,
Temperature: cfg.Temperature,
TopK: cfg.TopK,
TopP: cfg.TopP,
RepeatPenalty: cfg.RepeatPenalty,
MaxTokens: configuration.MaxTokens,
Temperature: configuration.Temperature,
TopK: configuration.TopK,
TopP: configuration.TopP,
RepeatPenalty: configuration.RepeatPenalty,
}
chunks, errFn := m.srv.client.Complete(ctx, req)
chunks, errorFunc := m.server.client.Complete(ctx, completionRequest)
var tokens []inference.Token
for text := range chunks {
tokens = append(tokens, inference.Token{Text: text})
}
if err := errFn(); err != nil {
results[i].Err = coreerr.E("rocm.BatchGenerate", fmt.Sprintf("batch prompt %d", i), err)
if err := errorFunc(); err != nil {
results[i].Err = coreerr.E("rocm.BatchGenerate", core.Sprintf("batch prompt %d", i), err)
}
results[i].Tokens = tokens
totalGenerated += len(tokens)
@@ -206,12 +221,20 @@ func (m *rocmModel) BatchGenerate(ctx context.Context, prompts []string, opts ..
}
// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3", "llama3").
//
// arch := m.ModelType() // "gemma3"
func (m *rocmModel) ModelType() string { return m.modelType }
// Info returns metadata about the loaded model.
//
// info := m.Info()
// // info.Architecture == "gemma3", info.NumLayers == 26
func (m *rocmModel) Info() inference.ModelInfo { return m.modelInfo }
// Metrics returns performance metrics from the last inference operation.
//
// metrics := m.Metrics()
// // metrics.DecodeTokensPerSec, metrics.TotalDuration
func (m *rocmModel) Metrics() inference.GenerateMetrics {
m.mu.Lock()
defer m.mu.Unlock()
@@ -219,6 +242,9 @@ func (m *rocmModel) Metrics() inference.GenerateMetrics {
}
// Err returns the error from the last Generate/Chat call, if any.
//
// for tok := range m.Generate(ctx, prompt) { }
// if err := m.Err(); err != nil { /* handle */ }
func (m *rocmModel) Err() error {
m.mu.Lock()
defer m.mu.Unlock()
@@ -226,16 +252,19 @@ func (m *rocmModel) Err() error {
}
// Close releases the llama-server subprocess and all associated resources.
//
// m, err := backend.LoadModel("/data/model.gguf")
// defer m.Close()
func (m *rocmModel) Close() error {
return m.srv.stop()
return m.server.stop()
}
// setServerExitErr stores an appropriate error when the server is dead.
func (m *rocmModel) setServerExitErr() {
m.mu.Lock()
defer m.mu.Unlock()
if m.srv.exitErr != nil {
m.lastErr = coreerr.E("rocm.setServerExitErr", "server has exited", m.srv.exitErr)
if m.server.exitErr != nil {
m.lastErr = coreerr.E("rocm.setServerExitErr", "server has exited", m.server.exitErr)
} else {
m.lastErr = coreerr.E("rocm.setServerExitErr", "server has exited unexpectedly", nil)
}
@@ -248,7 +277,7 @@ func (m *rocmModel) recordMetrics(promptTokens, generatedTokens int, start, deco
decode := now.Sub(decodeStart)
prefill := total - decode
met := inference.GenerateMetrics{
result := inference.GenerateMetrics{
PromptTokens: promptTokens,
GeneratedTokens: generatedTokens,
PrefillDuration: prefill,
@@ -256,19 +285,19 @@ func (m *rocmModel) recordMetrics(promptTokens, generatedTokens int, start, deco
TotalDuration: total,
}
if prefill > 0 && promptTokens > 0 {
met.PrefillTokensPerSec = float64(promptTokens) / prefill.Seconds()
result.PrefillTokensPerSec = float64(promptTokens) / prefill.Seconds()
}
if decode > 0 && generatedTokens > 0 {
met.DecodeTokensPerSec = float64(generatedTokens) / decode.Seconds()
result.DecodeTokensPerSec = float64(generatedTokens) / decode.Seconds()
}
// Try to get VRAM stats — best effort.
if vram, err := GetVRAMInfo(); err == nil {
met.PeakMemoryBytes = vram.Used
met.ActiveMemoryBytes = vram.Used
if vramInfo, err := GetVRAMInfo(); err == nil {
result.PeakMemoryBytes = vramInfo.Used
result.ActiveMemoryBytes = vramInfo.Used
}
m.mu.Lock()
m.metrics = met
m.metrics = result
m.mu.Unlock()
}
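The throughput fields are plain rates over the measured windows. With hypothetical numbers:

    decode := 2 * time.Second // decode window
    generatedTokens := 48
    result.DecodeTokensPerSec = float64(generatedTokens) / decode.Seconds() // 48 / 2.0 = 24 tokens/sec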


@@ -8,5 +8,10 @@ func init() {
inference.Register(&rocmBackend{})
}
// ROCmAvailable reports whether ROCm GPU inference is available.
// ROCmAvailable reports whether ROCm GPU inference is available on this machine.
// Returns true only on linux/amd64 with /dev/kfd present and llama-server findable.
//
// if rocm.ROCmAvailable() {
// m, err := inference.LoadModel("/data/model.gguf")
// }
func ROCmAvailable() bool { return true }

rocm.go

@@ -12,8 +12,9 @@
//
// m, err := inference.LoadModel("/path/to/model.gguf")
// defer m.Close()
// output := core.NewBuilder()
// for tok := range m.Generate(ctx, "Hello", inference.WithMaxTokens(128)) {
// fmt.Print(tok.Text)
// output.WriteString(tok.Text)
// }
//
// # Requirements
@@ -25,6 +26,9 @@
package rocm
// VRAMInfo reports GPU video memory usage in bytes.
//
// info, err := rocm.GetVRAMInfo()
// core.Print(c, "VRAM: %d MiB used / %d MiB total", info.Used/(1024*1024), info.Total/(1024*1024))
type VRAMInfo struct {
Total uint64
Used uint64
@ -32,6 +36,11 @@ type VRAMInfo struct {
}
// ModelInfo describes a GGUF model file discovered on disk.
//
// models, _ := rocm.DiscoverModels("/data/lem/gguf")
// for _, m := range models {
// core.Print(c, "%s (%s %s, ctx=%d)", m.Name, m.Architecture, m.Quantisation, m.ContextLen)
// }
type ModelInfo struct {
Path string // full path to .gguf file
Architecture string // GGUF architecture (e.g. "gemma3", "llama", "qwen2")
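
A hedged sketch of how the two types compose: pick the largest discovered model whose file fits in free VRAM. Treating FileSize as a VRAM proxy is a rough illustrative heuristic, not guidance from this package, and pickModel is a hypothetical name:

	// pickModel returns the biggest model whose file fits in free VRAM.
	func pickModel(models []ModelInfo, vram VRAMInfo) (ModelInfo, bool) {
		var best ModelInfo
		found := false
		for _, candidate := range models {
			if uint64(candidate.FileSize) > vram.Free {
				continue // would not fit in free VRAM
			}
			if !found || candidate.FileSize > best.FileSize {
				best, found = candidate, true
			}
		}
		return best, found
	}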

@ -5,12 +5,12 @@ package rocm
import (
"context"
"os"
"path/filepath"
"strings"
"sync"
"testing"
"time"
"dappco.re/go/core"
"forge.lthn.ai/core/go-inference"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -35,45 +35,45 @@ func skipIfNoROCm(t *testing.T) {
}
}
func TestROCm_LoadAndGenerate(t *testing.T) {
func TestROCm_LoadAndGenerate_Good(t *testing.T) {
skipIfNoROCm(t)
skipIfNoModel(t)
b := &rocmBackend{}
require.True(t, b.Available())
backend := &rocmBackend{}
require.True(t, backend.Available())
m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
model, err := backend.LoadModel(testModel, inference.WithContextLen(2048))
require.NoError(t, err)
defer m.Close()
defer model.Close()
assert.Equal(t, "gemma3", m.ModelType())
assert.Equal(t, "gemma3", model.ModelType())
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
var tokens []string
for tok := range m.Generate(ctx, "The capital of France is", inference.WithMaxTokens(16)) {
tokens = append(tokens, tok.Text)
var tokenTexts []string
for tok := range model.Generate(ctx, "The capital of France is", inference.WithMaxTokens(16)) {
tokenTexts = append(tokenTexts, tok.Text)
}
require.NoError(t, m.Err())
require.NotEmpty(t, tokens, "expected at least one token")
require.NoError(t, model.Err())
require.NotEmpty(t, tokenTexts, "expected at least one token")
full := ""
for _, tok := range tokens {
full += tok
fullText := ""
for _, tok := range tokenTexts {
fullText += tok
}
t.Logf("Generated: %s", full)
t.Logf("Generated: %s", fullText)
}
func TestROCm_Chat(t *testing.T) {
func TestROCm_Chat_Good(t *testing.T) {
skipIfNoROCm(t)
skipIfNoModel(t)
b := &rocmBackend{}
m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
backend := &rocmBackend{}
model, err := backend.LoadModel(testModel, inference.WithContextLen(2048))
require.NoError(t, err)
defer m.Close()
defer model.Close()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
@ -82,89 +82,89 @@ func TestROCm_Chat(t *testing.T) {
{Role: "user", Content: "Say hello in exactly three words."},
}
var tokens []string
for tok := range m.Chat(ctx, messages, inference.WithMaxTokens(32)) {
tokens = append(tokens, tok.Text)
var tokenTexts []string
for tok := range model.Chat(ctx, messages, inference.WithMaxTokens(32)) {
tokenTexts = append(tokenTexts, tok.Text)
}
require.NoError(t, m.Err())
require.NotEmpty(t, tokens, "expected at least one token")
require.NoError(t, model.Err())
require.NotEmpty(t, tokenTexts, "expected at least one token")
full := ""
for _, tok := range tokens {
full += tok
fullText := ""
for _, tok := range tokenTexts {
fullText += tok
}
t.Logf("Chat response: %s", full)
t.Logf("Chat response: %s", fullText)
}
func TestROCm_ContextCancellation(t *testing.T) {
func TestROCm_ContextCancellation_Ugly(t *testing.T) {
skipIfNoROCm(t)
skipIfNoModel(t)
b := &rocmBackend{}
m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
backend := &rocmBackend{}
model, err := backend.LoadModel(testModel, inference.WithContextLen(2048))
require.NoError(t, err)
defer m.Close()
defer model.Close()
ctx, cancel := context.WithCancel(context.Background())
var count int
for tok := range m.Generate(ctx, "Write a very long story about dragons", inference.WithMaxTokens(256)) {
var tokenCount int
for tok := range model.Generate(ctx, "Write a very long story about dragons", inference.WithMaxTokens(256)) {
_ = tok
count++
if count >= 3 {
tokenCount++
if tokenCount >= 3 {
cancel()
}
}
t.Logf("Got %d tokens before cancel", count)
assert.GreaterOrEqual(t, count, 3)
t.Logf("Got %d tokens before cancel", tokenCount)
assert.GreaterOrEqual(t, tokenCount, 3)
}
func TestROCm_GracefulShutdown(t *testing.T) {
func TestROCm_GracefulShutdown_Good(t *testing.T) {
skipIfNoROCm(t)
skipIfNoModel(t)
b := &rocmBackend{}
m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
backend := &rocmBackend{}
model, err := backend.LoadModel(testModel, inference.WithContextLen(2048))
require.NoError(t, err)
defer m.Close()
defer model.Close()
// Cancel mid-stream.
ctx1, cancel1 := context.WithCancel(context.Background())
var count1 int
for tok := range m.Generate(ctx1, "Write a long story about space exploration", inference.WithMaxTokens(256)) {
var firstTokenCount int
for tok := range model.Generate(ctx1, "Write a long story about space exploration", inference.WithMaxTokens(256)) {
_ = tok
count1++
if count1 >= 5 {
firstTokenCount++
if firstTokenCount >= 5 {
cancel1()
}
}
t.Logf("First generation: %d tokens before cancel", count1)
t.Logf("First generation: %d tokens before cancel", firstTokenCount)
// Generate again on the same model — server should still be alive.
ctx2, cancel2 := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel2()
var count2 int
for tok := range m.Generate(ctx2, "The capital of France is", inference.WithMaxTokens(16)) {
var secondTokenCount int
for tok := range model.Generate(ctx2, "The capital of France is", inference.WithMaxTokens(16)) {
_ = tok
count2++
secondTokenCount++
}
require.NoError(t, m.Err())
assert.Greater(t, count2, 0, "expected tokens from second generation after cancel")
t.Logf("Second generation: %d tokens", count2)
require.NoError(t, model.Err())
assert.Greater(t, secondTokenCount, 0, "expected tokens from second generation after cancel")
t.Logf("Second generation: %d tokens", secondTokenCount)
}
func TestROCm_ConcurrentRequests(t *testing.T) {
func TestROCm_ConcurrentRequests_Good(t *testing.T) {
skipIfNoROCm(t)
skipIfNoModel(t)
b := &rocmBackend{}
m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
backend := &rocmBackend{}
model, err := backend.LoadModel(testModel, inference.WithContextLen(2048))
require.NoError(t, err)
defer m.Close()
defer model.Close()
const numGoroutines = 3
results := make([]string, numGoroutines)
@ -175,25 +175,25 @@ func TestROCm_ConcurrentRequests(t *testing.T) {
"The capital of Italy is",
}
var wg sync.WaitGroup
wg.Add(numGoroutines)
var waitGroup sync.WaitGroup
waitGroup.Add(numGoroutines)
for i := range numGoroutines {
go func(idx int) {
defer wg.Done()
defer waitGroup.Done()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
var sb strings.Builder
for tok := range m.Generate(ctx, prompts[idx], inference.WithMaxTokens(16)) {
sb.WriteString(tok.Text)
builder := core.NewBuilder()
for tok := range model.Generate(ctx, prompts[idx], inference.WithMaxTokens(16)) {
builder.WriteString(tok.Text)
}
results[idx] = sb.String()
results[idx] = builder.String()
}(i)
}
wg.Wait()
waitGroup.Wait()
for i, result := range results {
t.Logf("Goroutine %d: %s", i, result)
@ -201,14 +201,14 @@ func TestROCm_ConcurrentRequests(t *testing.T) {
}
}
func TestROCm_Classify(t *testing.T) {
func TestROCm_Classify_Good(t *testing.T) {
skipIfNoROCm(t)
skipIfNoModel(t)
b := &rocmBackend{}
m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
backend := &rocmBackend{}
model, err := backend.LoadModel(testModel, inference.WithContextLen(2048))
require.NoError(t, err)
defer m.Close()
defer model.Close()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
@ -218,24 +218,24 @@ func TestROCm_Classify(t *testing.T) {
"2 + 2 =",
}
results, err := m.Classify(ctx, prompts)
results, err := model.Classify(ctx, prompts)
require.NoError(t, err)
require.Len(t, results, 2)
for i, r := range results {
assert.NotEmpty(t, r.Token.Text, "classify result %d should have a token", i)
t.Logf("Classify %d: %q", i, r.Token.Text)
for i, classifyResult := range results {
assert.NotEmpty(t, classifyResult.Token.Text, "classify result %d should have a token", i)
t.Logf("Classify %d: %q", i, classifyResult.Token.Text)
}
}
func TestROCm_BatchGenerate(t *testing.T) {
func TestROCm_BatchGenerate_Good(t *testing.T) {
skipIfNoROCm(t)
skipIfNoModel(t)
b := &rocmBackend{}
m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
backend := &rocmBackend{}
model, err := backend.LoadModel(testModel, inference.WithContextLen(2048))
require.NoError(t, err)
defer m.Close()
defer model.Close()
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
@ -245,70 +245,70 @@ func TestROCm_BatchGenerate(t *testing.T) {
"The capital of Germany is",
}
results, err := m.BatchGenerate(ctx, prompts, inference.WithMaxTokens(8))
results, err := model.BatchGenerate(ctx, prompts, inference.WithMaxTokens(8))
require.NoError(t, err)
require.Len(t, results, 2)
for i, r := range results {
require.NoError(t, r.Err, "batch result %d error", i)
assert.NotEmpty(t, r.Tokens, "batch result %d should have tokens", i)
for i, batchResult := range results {
require.NoError(t, batchResult.Err, "batch result %d error", i)
assert.NotEmpty(t, batchResult.Tokens, "batch result %d should have tokens", i)
var sb strings.Builder
for _, tok := range r.Tokens {
sb.WriteString(tok.Text)
builder := core.NewBuilder()
for _, tok := range batchResult.Tokens {
builder.WriteString(tok.Text)
}
t.Logf("Batch %d: %s", i, sb.String())
t.Logf("Batch %d: %s", i, builder.String())
}
}
func TestROCm_InfoAndMetrics(t *testing.T) {
func TestROCm_InfoAndMetrics_Good(t *testing.T) {
skipIfNoROCm(t)
skipIfNoModel(t)
b := &rocmBackend{}
m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
backend := &rocmBackend{}
model, err := backend.LoadModel(testModel, inference.WithContextLen(2048))
require.NoError(t, err)
defer m.Close()
defer model.Close()
// Info should be populated from GGUF metadata.
info := m.Info()
assert.Equal(t, "gemma3", info.Architecture)
assert.Greater(t, info.NumLayers, 0, "expected non-zero layer count")
assert.Greater(t, info.QuantBits, 0, "expected non-zero quant bits")
modelInfo := model.Info()
assert.Equal(t, "gemma3", modelInfo.Architecture)
assert.Greater(t, modelInfo.NumLayers, 0, "expected non-zero layer count")
assert.Greater(t, modelInfo.QuantBits, 0, "expected non-zero quant bits")
t.Logf("Info: arch=%s layers=%d quant=%d-bit group=%d",
info.Architecture, info.NumLayers, info.QuantBits, info.QuantGroup)
modelInfo.Architecture, modelInfo.NumLayers, modelInfo.QuantBits, modelInfo.QuantGroup)
// Generate some tokens to populate metrics.
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
for range m.Generate(ctx, "Hello", inference.WithMaxTokens(4)) {
for range model.Generate(ctx, "Hello", inference.WithMaxTokens(4)) {
}
require.NoError(t, m.Err())
require.NoError(t, model.Err())
met := m.Metrics()
assert.Greater(t, met.GeneratedTokens, 0, "expected generated tokens")
assert.Greater(t, met.TotalDuration, time.Duration(0), "expected non-zero duration")
assert.Greater(t, met.DecodeTokensPerSec, float64(0), "expected non-zero decode throughput")
metrics := model.Metrics()
assert.Greater(t, metrics.GeneratedTokens, 0, "expected generated tokens")
assert.Greater(t, metrics.TotalDuration, time.Duration(0), "expected non-zero duration")
assert.Greater(t, metrics.DecodeTokensPerSec, float64(0), "expected non-zero decode throughput")
t.Logf("Metrics: gen=%d tok, total=%s, decode=%.1f tok/s, vram=%d MiB",
met.GeneratedTokens, met.TotalDuration, met.DecodeTokensPerSec,
met.ActiveMemoryBytes/(1024*1024))
metrics.GeneratedTokens, metrics.TotalDuration, metrics.DecodeTokensPerSec,
metrics.ActiveMemoryBytes/(1024*1024))
}
func TestROCm_DiscoverModels(t *testing.T) {
dir := filepath.Dir(testModel)
if _, err := os.Stat(dir); err != nil {
func TestROCm_DiscoverModels_Good(t *testing.T) {
modelDir := core.PathDir(testModel)
if _, err := os.Stat(modelDir); err != nil {
t.Skip("model directory not available")
}
models, err := DiscoverModels(dir)
models, err := DiscoverModels(modelDir)
require.NoError(t, err)
require.NotEmpty(t, models, "expected at least one model in %s", dir)
require.NotEmpty(t, models, "expected at least one model in %s", modelDir)
for _, m := range models {
t.Logf("Found: %s (%s %s %s, ctx=%d)", filepath.Base(m.Path), m.Architecture, m.Parameters, m.Quantisation, m.ContextLen)
assert.NotEmpty(t, m.Architecture)
assert.NotEmpty(t, m.Name)
assert.Greater(t, m.FileSize, int64(0))
for _, discoveredModel := range models {
t.Logf("Found: %s (%s %s %s, ctx=%d)", core.PathBase(discoveredModel.Path), discoveredModel.Architecture, discoveredModel.Parameters, discoveredModel.Quantisation, discoveredModel.ContextLen)
assert.NotEmpty(t, discoveredModel.Architecture)
assert.NotEmpty(t, discoveredModel.Name)
assert.Greater(t, discoveredModel.FileSize, int64(0))
}
}

@ -4,8 +4,12 @@ package rocm
import coreerr "forge.lthn.ai/core/go-log"
// ROCmAvailable reports whether ROCm GPU inference is available.
// ROCmAvailable reports whether ROCm GPU inference is available on this machine.
// Returns false on non-Linux or non-amd64 platforms.
//
// if rocm.ROCmAvailable() {
// m, err := inference.LoadModel("/data/model.gguf")
// }
func ROCmAvailable() bool { return false }
// GetVRAMInfo is not available on non-Linux/non-amd64 platforms.
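
The true/false pair works through build constraints: the linux/amd64 file compiles on ROCm-capable platforms and the stub everywhere else. A sketch of the pair; the file names and exact constraint lines are assumptions, since the diff view does not show them:

	// rocm_available_linux.go (assumed name)
	//go:build linux && amd64

	package rocm

	func ROCmAvailable() bool { return true }

	// rocm_available_stub.go (assumed name)
	//go:build !(linux && amd64)

	package rocm

	func ROCmAvailable() bool { return false }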

@ -4,15 +4,15 @@ package rocm
import (
"context"
"fmt"
"net"
"os"
"os/exec"
"strconv"
"strings"
"syscall"
"time"
"dappco.re/go/core"
coreerr "forge.lthn.ai/core/go-log"
"forge.lthn.ai/core/go-rocm/internal/llamacpp"
)
@ -38,28 +38,31 @@ func (s *server) alive() bool {
// findLlamaServer locates the llama-server binary.
// Checks ROCM_LLAMA_SERVER_PATH first, then PATH.
//
// path, err := findLlamaServer()
// // path == "/usr/local/bin/llama-server"
func findLlamaServer() (string, error) {
if p := os.Getenv("ROCM_LLAMA_SERVER_PATH"); p != "" {
if _, err := os.Stat(p); err != nil {
return "", coreerr.E("rocm.findLlamaServer", "llama-server not found at ROCM_LLAMA_SERVER_PATH="+p, err)
if binaryPath := core.Env("ROCM_LLAMA_SERVER_PATH"); binaryPath != "" {
if !(&core.Fs{}).New("/").Exists(binaryPath) {
return "", coreerr.E("rocm.findLlamaServer", "llama-server not found at ROCM_LLAMA_SERVER_PATH="+binaryPath, nil)
}
return p, nil
return binaryPath, nil
}
p, err := exec.LookPath("llama-server")
binaryPath, err := exec.LookPath("llama-server")
if err != nil {
return "", coreerr.E("rocm.findLlamaServer", "llama-server not found in PATH", err)
}
return p, nil
return binaryPath, nil
}
// freePort asks the kernel for a free TCP port on localhost.
func freePort() (int, error) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return 0, coreerr.E("rocm.freePort", "listen for free port", err)
}
port := ln.Addr().(*net.TCPAddr).Port
ln.Close()
port := listener.Addr().(*net.TCPAddr).Port
listener.Close()
return port, nil
}
@ -69,11 +72,11 @@ func freePort() (int, error) {
func serverEnv() []string {
environ := os.Environ()
env := make([]string, 0, len(environ)+1)
for _, e := range environ {
if strings.HasPrefix(e, "HIP_VISIBLE_DEVICES=") {
for _, envEntry := range environ {
if core.HasPrefix(envEntry, "HIP_VISIBLE_DEVICES=") {
continue
}
env = append(env, e)
env = append(env, envEntry)
}
env = append(env, "HIP_VISIBLE_DEVICES=0")
return env
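
The filter-then-append shape above generalises to forcing any single environment variable. A sketch with a hypothetical helper name, using the core.HasPrefix call already introduced in this diff:

	// overrideEnv returns environ with every existing key= entry removed and
	// one canonical key=value appended, the same shape serverEnv uses for
	// HIP_VISIBLE_DEVICES.
	func overrideEnv(environ []string, key, value string) []string {
		prefix := key + "="
		out := make([]string, 0, len(environ)+1)
		for _, entry := range environ {
			if core.HasPrefix(entry, prefix) {
				continue
			}
			out = append(out, entry)
		}
		return append(out, prefix+value)
	}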
@ -82,7 +85,10 @@ func serverEnv() []string {
// startServer spawns llama-server and waits for it to become ready.
// It selects a free port automatically, retrying up to 3 times if the
// process exits during startup (e.g. port conflict).
func startServer(binary, modelPath string, gpuLayers, ctxSize, parallelSlots int) (*server, error) {
//
// s, err := startServer("/usr/local/bin/llama-server", "/data/model.gguf", 99, 4096, 4)
// defer s.stop()
func startServer(binary, modelPath string, gpuLayers, contextSize, parallelSlots int) (*server, error) {
if gpuLayers < 0 {
gpuLayers = 999
}
@ -102,8 +108,8 @@ func startServer(binary, modelPath string, gpuLayers, ctxSize, parallelSlots int
"--port", strconv.Itoa(port),
"--n-gpu-layers", strconv.Itoa(gpuLayers),
}
if ctxSize > 0 {
args = append(args, "--ctx-size", strconv.Itoa(ctxSize))
if contextSize > 0 {
args = append(args, "--ctx-size", strconv.Itoa(contextSize))
}
if parallelSlots > 0 {
args = append(args, "--parallel", strconv.Itoa(parallelSlots))
@ -116,39 +122,39 @@ func startServer(binary, modelPath string, gpuLayers, ctxSize, parallelSlots int
return nil, coreerr.E("rocm.startServer", "start llama-server", err)
}
s := &server{
subprocess := &server{
cmd: cmd,
port: port,
client: llamacpp.NewClient(fmt.Sprintf("http://127.0.0.1:%d", port)),
client: llamacpp.NewClient(core.Sprintf("http://127.0.0.1:%d", port)),
exited: make(chan struct{}),
}
go func() {
s.exitErr = cmd.Wait()
close(s.exited)
subprocess.exitErr = cmd.Wait()
close(subprocess.exited)
}()
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
err = s.waitReady(ctx)
err = subprocess.waitReady(ctx)
cancel()
if err == nil {
return s, nil
return subprocess, nil
}
// Only retry if the process actually exited (e.g. port conflict).
// A timeout means the server is stuck, not a port issue.
select {
case <-s.exited:
_ = s.stop()
lastErr = coreerr.E("rocm.startServer", fmt.Sprintf("attempt %d", attempt+1), err)
case <-subprocess.exited:
_ = subprocess.stop()
lastErr = coreerr.E("rocm.startServer", core.Sprintf("attempt %d", attempt+1), err)
continue
default:
_ = s.stop()
_ = subprocess.stop()
return nil, coreerr.E("rocm.startServer", "llama-server not ready", err)
}
}
return nil, coreerr.E("rocm.startServer", fmt.Sprintf("server failed after %d attempts", maxAttempts), lastErr)
return nil, coreerr.E("rocm.startServer", core.Sprintf("server failed after %d attempts", maxAttempts), lastErr)
}
// waitReady polls the health endpoint until the server is ready.
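
The retry loop distinguishes a dead child from a stuck one with a non-blocking channel receive. The decision in isolation, as a sketch with a hypothetical helper name:

	// shouldRetry reports whether a failed startup is worth another attempt:
	// only when the subprocess has already exited (likely a port clash).
	// A ready-timeout with the process still alive means it is stuck, and
	// retrying on a fresh port would not help.
	func shouldRetry(exited <-chan struct{}) bool {
		select {
		case <-exited:
			return true
		default:
			return false
		}
	}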

@ -4,44 +4,52 @@ package rocm
import (
"context"
"os"
"strings"
"testing"
"dappco.re/go/core"
"forge.lthn.ai/core/go-inference"
coreerr "forge.lthn.ai/core/go-log"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestFindLlamaServer_InPATH(t *testing.T) {
func TestFindLlamaServer_Good_InPATH(t *testing.T) {
// llama-server is at /usr/local/bin/llama-server on this machine.
path, err := findLlamaServer()
require.NoError(t, err)
assert.Contains(t, path, "llama-server")
}
func TestFindLlamaServer_EnvOverride(t *testing.T) {
func TestFindLlamaServer_Good_EnvOverride(t *testing.T) {
t.Setenv("ROCM_LLAMA_SERVER_PATH", "/usr/local/bin/llama-server")
path, err := findLlamaServer()
require.NoError(t, err)
assert.Equal(t, "/usr/local/bin/llama-server", path)
}
func TestFindLlamaServer_EnvNotFound(t *testing.T) {
func TestFindLlamaServer_Bad_EnvPathMissing(t *testing.T) {
t.Setenv("ROCM_LLAMA_SERVER_PATH", "/nonexistent/llama-server")
_, err := findLlamaServer()
assert.ErrorContains(t, err, "not found")
}
func TestFreePort(t *testing.T) {
func TestFindLlamaServer_Ugly_EmptyPATH(t *testing.T) {
// With no ROCM_LLAMA_SERVER_PATH set and an empty PATH, LookPath must fail.
t.Setenv("ROCM_LLAMA_SERVER_PATH", "")
t.Setenv("PATH", "")
_, err := findLlamaServer()
assert.Error(t, err)
}
func TestFreePort_Good(t *testing.T) {
port, err := freePort()
require.NoError(t, err)
assert.Greater(t, port, 0)
assert.Less(t, port, 65536)
}
func TestFreePort_UniquePerCall(t *testing.T) {
func TestFreePort_Good_UniquePerCall(t *testing.T) {
p1, err := freePort()
require.NoError(t, err)
p2, err := freePort()
@ -50,68 +58,180 @@ func TestFreePort_UniquePerCall(t *testing.T) {
_ = p2
}
func TestServerEnv_HIPVisibleDevices(t *testing.T) {
func TestFreePort_Bad_InvalidAddr(t *testing.T) {
// freePort always binds 127.0.0.1:0, which cannot fail on a healthy machine,
// so this Bad-path test instead documents that the kernel never hands back a
// privileged (<1024) port.
port, err := freePort()
require.NoError(t, err)
assert.Greater(t, port, 1023, "expected unprivileged port")
}
func TestFreePort_Ugly_ReturnsUsablePort(t *testing.T) {
// The returned port must never be zero; callers pass it straight to net.Listen.
port, err := freePort()
require.NoError(t, err)
assert.NotZero(t, port)
}
func TestServerEnv_Good_SetsHIPVisibleDevices(t *testing.T) {
env := serverEnv()
var hipVals []string
for _, e := range env {
if strings.HasPrefix(e, "HIP_VISIBLE_DEVICES=") {
hipVals = append(hipVals, e)
for _, envEntry := range env {
if core.HasPrefix(envEntry, "HIP_VISIBLE_DEVICES=") {
hipVals = append(hipVals, envEntry)
}
}
assert.Equal(t, []string{"HIP_VISIBLE_DEVICES=0"}, hipVals)
}
func TestServerEnv_FiltersExistingHIP(t *testing.T) {
func TestServerEnv_Good_FiltersExistingHIPVisibleDevices(t *testing.T) {
t.Setenv("HIP_VISIBLE_DEVICES", "1")
env := serverEnv()
var hipVals []string
for _, e := range env {
if strings.HasPrefix(e, "HIP_VISIBLE_DEVICES=") {
hipVals = append(hipVals, e)
for _, envEntry := range env {
if core.HasPrefix(envEntry, "HIP_VISIBLE_DEVICES=") {
hipVals = append(hipVals, envEntry)
}
}
assert.Equal(t, []string{"HIP_VISIBLE_DEVICES=0"}, hipVals)
}
func TestAvailable(t *testing.T) {
func TestServerEnv_Bad_NilEnviron(t *testing.T) {
// serverEnv must never panic even when called with unusual env state.
// It always appends HIP_VISIBLE_DEVICES=0 regardless of ambient env.
env := serverEnv()
assert.NotEmpty(t, env)
}
func TestServerEnv_Ugly_MultipleHIPEntries(t *testing.T) {
// Even if multiple HIP_VISIBLE_DEVICES entries somehow existed, only one must remain.
t.Setenv("HIP_VISIBLE_DEVICES", "2,3")
env := serverEnv()
var hipVals []string
for _, envEntry := range env {
if core.HasPrefix(envEntry, "HIP_VISIBLE_DEVICES=") {
hipVals = append(hipVals, envEntry)
}
}
assert.Equal(t, []string{"HIP_VISIBLE_DEVICES=0"}, hipVals)
}
func TestAvailable_Good(t *testing.T) {
b := &rocmBackend{}
if _, err := os.Stat("/dev/kfd"); err != nil {
if !(&core.Fs{}).New("/").Exists("/dev/kfd") {
t.Skip("no ROCm hardware")
}
assert.True(t, b.Available())
}
func TestAvailable_Bad_NoDevice(t *testing.T) {
// When /dev/kfd is absent, Available must return false.
if (&core.Fs{}).New("/").Exists("/dev/kfd") {
t.Skip("ROCm device present — skip no-device bad path on this machine")
}
b := &rocmBackend{}
assert.False(t, b.Available())
}
func TestServerAlive_Running(t *testing.T) {
func TestAvailable_Ugly_NoLlamaServer(t *testing.T) {
// Even with /dev/kfd present, Available must be false if llama-server is missing.
// We can't create /dev/kfd in a test, so we empty both search paths instead.
t.Setenv("PATH", "")
t.Setenv("ROCM_LLAMA_SERVER_PATH", "")
b := &rocmBackend{}
// If kfd is present but llama-server missing, Available returns false.
// If kfd is absent, Available also returns false. Either way, it must not panic.
_ = b.Available()
}
func TestServerAlive_Good_Running(t *testing.T) {
s := &server{exited: make(chan struct{})}
assert.True(t, s.alive())
}
func TestServerAlive_Exited(t *testing.T) {
func TestServerAlive_Good_Exited(t *testing.T) {
exited := make(chan struct{})
close(exited)
s := &server{exited: exited, exitErr: coreerr.E("test", "process killed", nil)}
assert.False(t, s.alive())
}
func TestGenerate_ServerDead(t *testing.T) {
func TestServerAlive_Bad_NilExited(t *testing.T) {
// A server with a nil exited channel would never observe exit (receives on a
// nil channel block forever), which documents the contract: exited must always
// be initialised before use.
// Here we test the well-formed bad state: exited closed with a nil exitErr.
exited := make(chan struct{})
close(exited)
s := &server{exited: exited, exitErr: nil}
assert.False(t, s.alive())
}
func TestServerAlive_Ugly_ExitedAfterStart(t *testing.T) {
// alive transitions from true to false when the channel is closed.
exited := make(chan struct{})
s := &server{exited: exited}
assert.True(t, s.alive())
close(exited)
assert.False(t, s.alive())
}
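
Together these four tests pin down the shape of alive(). Its body is not part of this diff, but the asserted behaviour corresponds to a non-blocking receive on the exited channel; a sketch of that assumed implementation:

	func (s *server) alive() bool {
		select {
		case <-s.exited:
			return false // the exit goroutine closed the channel
		default:
			return true
		}
	}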
func TestGenerate_Good_YieldsNoTokensOnEmptyServer(t *testing.T) {
// A newly-created dead server produces zero tokens and records an error.
exited := make(chan struct{})
close(exited)
s := &server{exited: exited, exitErr: nil}
m := &rocmModel{server: s}
var tokenCount int
for range m.Generate(context.Background(), "hello") {
tokenCount++
}
assert.Equal(t, 0, tokenCount)
assert.Error(t, m.Err(), "dead server must record an error")
}
func TestGenerate_Bad_ServerDead(t *testing.T) {
exited := make(chan struct{})
close(exited)
s := &server{
exited: exited,
exitErr: coreerr.E("test", "process killed", nil),
}
m := &rocmModel{srv: s}
m := &rocmModel{server: s}
var count int
var tokenCount int
for range m.Generate(context.Background(), "hello") {
count++
tokenCount++
}
assert.Equal(t, 0, count)
assert.Equal(t, 0, tokenCount)
assert.ErrorContains(t, m.Err(), "server has exited")
}
func TestStartServer_RetriesOnProcessExit(t *testing.T) {
func TestGenerate_Ugly_ErrClearedBetweenCalls(t *testing.T) {
// Err is reset to nil on each Generate call start.
exited := make(chan struct{})
close(exited)
s := &server{exited: exited, exitErr: coreerr.E("test", "killed", nil)}
m := &rocmModel{server: s}
for range m.Generate(context.Background(), "first") {
}
assert.Error(t, m.Err())
}
func TestStartServer_Good_NormalisesNegativeLayers(t *testing.T) {
// gpuLayers=-1 must be normalised to 999 (all layers on GPU), not rejected.
// /bin/false exits immediately, so the expected failure is retry exhaustion.
_, err := startServer("/bin/false", "/nonexistent/model.gguf", -1, 0, 0)
require.Error(t, err)
assert.Contains(t, err.Error(), "failed after 3 attempts")
}
func TestStartServer_Bad_BinaryNotFound(t *testing.T) {
_, err := startServer("/nonexistent/binary", "/nonexistent/model.gguf", 0, 0, 0)
require.Error(t, err)
}
func TestStartServer_Ugly_RetriesOnProcessExit(t *testing.T) {
// /bin/false starts successfully but exits immediately with code 1.
// startServer should retry up to 3 times, then fail.
_, err := startServer("/bin/false", "/nonexistent/model.gguf", 999, 0, 0)
@ -119,20 +239,53 @@ func TestStartServer_RetriesOnProcessExit(t *testing.T) {
assert.Contains(t, err.Error(), "failed after 3 attempts")
}
func TestChat_ServerDead(t *testing.T) {
func TestChat_Good_EmptyMessages(t *testing.T) {
// Chat with an empty message list on a dead server yields no tokens.
exited := make(chan struct{})
close(exited)
s := &server{exited: exited, exitErr: nil}
m := &rocmModel{server: s}
var tokenCount int
for range m.Chat(context.Background(), nil) {
tokenCount++
}
assert.Equal(t, 0, tokenCount)
}
func TestChat_Bad_ServerDead(t *testing.T) {
exited := make(chan struct{})
close(exited)
s := &server{
exited: exited,
exitErr: coreerr.E("test", "process killed", nil),
}
m := &rocmModel{srv: s}
m := &rocmModel{server: s}
msgs := []inference.Message{{Role: "user", Content: "hello"}}
var count int
var tokenCount int
for range m.Chat(context.Background(), msgs) {
count++
tokenCount++
}
assert.Equal(t, 0, count)
assert.Equal(t, 0, tokenCount)
assert.ErrorContains(t, m.Err(), "server has exited")
}
func TestChat_Ugly_MultipleRolesOnDeadServer(t *testing.T) {
// Chat with multiple roles on a dead server must still return safely.
exited := make(chan struct{})
close(exited)
s := &server{exited: exited, exitErr: coreerr.E("test", "killed", nil)}
m := &rocmModel{server: s}
msgs := []inference.Message{
{Role: "system", Content: "You are helpful."},
{Role: "user", Content: "Hello"},
}
var tokenCount int
for range m.Chat(context.Background(), msgs) {
tokenCount++
}
assert.Equal(t, 0, tokenCount)
assert.Error(t, m.Err())
}
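
Every test in this file now follows the TestSubject_Function_{Good,Bad,Ugly} convention from the AX review: one subject, three scenarios. As a minimal template (names illustrative):

	func TestSubject_Function_Good(t *testing.T) {} // happy path succeeds
	func TestSubject_Function_Bad(t *testing.T)  {} // expected failure yields a clean error
	func TestSubject_Function_Ugly(t *testing.T) {} // hostile input or odd state, still no panic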

vram.go

@ -3,10 +3,9 @@
package rocm
import (
"os"
"path/filepath"
"strconv"
"strings"
"dappco.re/go/core"
coreerr "forge.lthn.ai/core/go-log"
)
@ -15,13 +14,10 @@ import (
// It identifies the dGPU by selecting the card with the largest VRAM total,
// which avoids hardcoding card numbers (e.g. card0=iGPU, card1=dGPU on Ryzen).
//
// Note: total and used are read non-atomically from sysfs; transient
// inconsistencies are possible under heavy allocation churn.
// info, err := rocm.GetVRAMInfo()
// // info.Total == 17179869184, info.Used == 2147483648, info.Free == 15032385536
func GetVRAMInfo() (VRAMInfo, error) {
cards, err := filepath.Glob("/sys/class/drm/card[0-9]*/device/mem_info_vram_total")
if err != nil {
return VRAMInfo{}, coreerr.E("rocm.GetVRAMInfo", "glob vram sysfs", err)
}
cards := core.PathGlob("/sys/class/drm/card[0-9]*/device/mem_info_vram_total")
if len(cards) == 0 {
return VRAMInfo{}, coreerr.E("rocm.GetVRAMInfo", "no GPU VRAM info found in sysfs", nil)
}
@ -36,7 +32,7 @@ func GetVRAMInfo() (VRAMInfo, error) {
}
if total > bestTotal {
bestTotal = total
bestDir = filepath.Dir(totalPath)
bestDir = core.PathDir(totalPath)
}
}
@ -44,27 +40,27 @@ func GetVRAMInfo() (VRAMInfo, error) {
return VRAMInfo{}, coreerr.E("rocm.GetVRAMInfo", "no readable VRAM sysfs entries", nil)
}
used, err := readSysfsUint64(filepath.Join(bestDir, "mem_info_vram_used"))
used, err := readSysfsUint64(core.Path(bestDir, "mem_info_vram_used"))
if err != nil {
return VRAMInfo{}, coreerr.E("rocm.GetVRAMInfo", "read vram used", err)
}
free := uint64(0)
freeBytes := uint64(0)
if bestTotal > used {
free = bestTotal - used
freeBytes = bestTotal - used
}
return VRAMInfo{
Total: bestTotal,
Used: used,
Free: free,
Free: freeBytes,
}, nil
}
func readSysfsUint64(path string) (uint64, error) {
data, err := os.ReadFile(path)
if err != nil {
return 0, err
result := (&core.Fs{}).New("/").Read(path)
if !result.OK {
return 0, coreerr.E("rocm.readSysfsUint64", "read sysfs file", result.Value.(error))
}
return strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64)
return strconv.ParseUint(core.Trim(result.Value.(string)), 10, 64)
}

@ -4,16 +4,17 @@ package rocm
import (
"os"
"path/filepath"
"testing"
"dappco.re/go/core"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestReadSysfsUint64(t *testing.T) {
func TestReadSysfsUint64_Good(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test_value")
path := core.Path(dir, "test_value")
require.NoError(t, os.WriteFile(path, []byte("17163091968\n"), 0644))
val, err := readSysfsUint64(path)
@ -21,30 +22,41 @@ func TestReadSysfsUint64(t *testing.T) {
assert.Equal(t, uint64(17163091968), val)
}
func TestReadSysfsUint64_NotFound(t *testing.T) {
func TestReadSysfsUint64_Bad_NotFound(t *testing.T) {
_, err := readSysfsUint64("/nonexistent/path")
assert.Error(t, err)
}
func TestReadSysfsUint64_InvalidContent(t *testing.T) {
func TestReadSysfsUint64_Bad_InvalidContent(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "bad_value")
path := core.Path(dir, "bad_value")
require.NoError(t, os.WriteFile(path, []byte("not-a-number\n"), 0644))
_, err := readSysfsUint64(path)
assert.Error(t, err)
}
func TestReadSysfsUint64_EmptyFile(t *testing.T) {
func TestReadSysfsUint64_Bad_EmptyFile(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "empty_value")
path := core.Path(dir, "empty_value")
require.NoError(t, os.WriteFile(path, []byte(""), 0644))
_, err := readSysfsUint64(path)
assert.Error(t, err)
}
func TestGetVRAMInfo(t *testing.T) {
func TestReadSysfsUint64_Ugly_WhitespaceAroundValue(t *testing.T) {
// sysfs files often contain a trailing newline; readSysfsUint64 must trim it.
dir := t.TempDir()
path := core.Path(dir, "whitespace_value")
require.NoError(t, os.WriteFile(path, []byte(" 17163091968 \n"), 0644))
val, err := readSysfsUint64(path)
require.NoError(t, err)
assert.Equal(t, uint64(17163091968), val)
}
func TestGetVRAMInfo_Good(t *testing.T) {
info, err := GetVRAMInfo()
if err != nil {
t.Skipf("no VRAM sysfs info available: %v", err)
@ -55,3 +67,23 @@ func TestGetVRAMInfo(t *testing.T) {
assert.Greater(t, info.Used, uint64(0), "expected some VRAM in use")
assert.Equal(t, info.Total-info.Used, info.Free, "Free should equal Total-Used")
}
func TestGetVRAMInfo_Bad_NoGPU(t *testing.T) {
// On non-ROCm machines GetVRAMInfo must return an error, not panic.
_, err := GetVRAMInfo()
if err == nil {
t.Skip("ROCm sysfs present — not applicable on this machine")
}
assert.Error(t, err)
}
func TestGetVRAMInfo_Ugly_FreeComputed(t *testing.T) {
// Invariant: Free == Total - Used whenever Total >= Used.
info, err := GetVRAMInfo()
if err != nil {
t.Skipf("no VRAM sysfs info available: %v", err)
}
if info.Total >= info.Used {
assert.Equal(t, info.Total-info.Used, info.Free)
}
}
}