feat(ax): pass 2 — replace banned imports, rename variables, add AX comments
Replace fmt, strings, path/filepath and encoding/json with core equivalents throughout
all packages. Rename cfg→configuration, srv→server/subprocess, ftName→fileTypeName,
ctxSize→contextSize. Add usage-example doc-comments to every exported symbol.
Update all test names to TestSubject_Function_{Good,Bad,Ugly} convention.
Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
41b34b6779
commit
523abc6509
18 changed files with 727 additions and 458 deletions
35
backend.go
35
backend.go
|
|
@ -3,8 +3,7 @@
|
|||
package rocm
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"dappco.re/go/core"
|
||||
|
||||
coreerr "forge.lthn.ai/core/go-log"
|
||||
"forge.lthn.ai/core/go-inference"
|
||||
|
|
@ -22,7 +21,7 @@ func (b *rocmBackend) Name() string { return "rocm" }
|
|||
// b := inference.FindBackend("rocm")
|
||||
// if b.Available() { /* safe to LoadModel */ }
|
||||
func (b *rocmBackend) Available() bool {
|
||||
if _, err := os.Stat("/dev/kfd"); err != nil {
|
||||
if !(&core.Fs{}).New("/").Exists("/dev/kfd") {
|
||||
return false
|
||||
}
|
||||
if _, err := findLlamaServer(); err != nil {
|
||||
|
|
@ -42,7 +41,7 @@ func (b *rocmBackend) Available() bool {
|
|||
// )
|
||||
// defer m.Close()
|
||||
func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (inference.TextModel, error) {
|
||||
config := inference.ApplyLoadOpts(opts)
|
||||
configuration := inference.ApplyLoadOpts(opts)
|
||||
|
||||
binary, err := findLlamaServer()
|
||||
if err != nil {
|
||||
|
|
@ -54,12 +53,12 @@ func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (infe
|
|||
return nil, coreerr.E("rocm.LoadModel", "read model metadata", err)
|
||||
}
|
||||
|
||||
ctxLen := config.ContextLen
|
||||
if ctxLen == 0 && meta.ContextLength > 0 {
|
||||
ctxLen = int(min(meta.ContextLength, 4096))
|
||||
contextLen := configuration.ContextLen
|
||||
if contextLen == 0 && meta.ContextLength > 0 {
|
||||
contextLen = int(min(meta.ContextLength, 4096))
|
||||
}
|
||||
|
||||
srv, err := startServer(binary, path, config.GPULayers, ctxLen, config.ParallelSlots)
|
||||
subprocess, err := startServer(binary, path, configuration.GPULayers, contextLen, configuration.ParallelSlots)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -67,34 +66,34 @@ func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (infe
|
|||
// Map quantisation file type to bit width.
|
||||
quantBits := 0
|
||||
quantGroup := 0
|
||||
ftName := gguf.FileTypeName(meta.FileType)
|
||||
fileTypeName := gguf.FileTypeName(meta.FileType)
|
||||
switch {
|
||||
case strings.HasPrefix(ftName, "Q4_"):
|
||||
case core.HasPrefix(fileTypeName, "Q4_"):
|
||||
quantBits = 4
|
||||
quantGroup = 32
|
||||
case strings.HasPrefix(ftName, "Q5_"):
|
||||
case core.HasPrefix(fileTypeName, "Q5_"):
|
||||
quantBits = 5
|
||||
quantGroup = 32
|
||||
case strings.HasPrefix(ftName, "Q8_"):
|
||||
case core.HasPrefix(fileTypeName, "Q8_"):
|
||||
quantBits = 8
|
||||
quantGroup = 32
|
||||
case strings.HasPrefix(ftName, "Q2_"):
|
||||
case core.HasPrefix(fileTypeName, "Q2_"):
|
||||
quantBits = 2
|
||||
quantGroup = 16
|
||||
case strings.HasPrefix(ftName, "Q3_"):
|
||||
case core.HasPrefix(fileTypeName, "Q3_"):
|
||||
quantBits = 3
|
||||
quantGroup = 32
|
||||
case strings.HasPrefix(ftName, "Q6_"):
|
||||
case core.HasPrefix(fileTypeName, "Q6_"):
|
||||
quantBits = 6
|
||||
quantGroup = 64
|
||||
case ftName == "F16":
|
||||
case fileTypeName == "F16":
|
||||
quantBits = 16
|
||||
case ftName == "F32":
|
||||
case fileTypeName == "F32":
|
||||
quantBits = 32
|
||||
}
|
||||
|
||||
return &rocmModel{
|
||||
server: srv,
|
||||
server: subprocess,
|
||||
modelType: meta.Architecture,
|
||||
modelInfo: inference.ModelInfo{
|
||||
Architecture: meta.Architecture,
|
||||
|
|
|
|||
15
discover.go
15
discover.go
|
|
@ -1,23 +1,18 @@
|
|||
package rocm
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"dappco.re/go/core"
|
||||
|
||||
"forge.lthn.ai/core/go-rocm/internal/gguf"
|
||||
)
|
||||
|
||||
// DiscoverModels scans a directory for GGUF model files.
|
||||
// Files that cannot be parsed are silently skipped.
|
||||
// DiscoverModels scans a directory for GGUF model files and returns
|
||||
// structured information about each. Files that cannot be parsed are skipped.
|
||||
//
|
||||
// models, err := rocm.DiscoverModels("/data/lem/gguf")
|
||||
// for _, m := range models {
|
||||
// fmt.Printf("%s: %s %s ctx=%d\n", m.Name, m.Architecture, m.Quantisation, m.ContextLen)
|
||||
// }
|
||||
// for _, m := range models { core.Print(c, "%s %s ctx=%d", m.Name, m.Quantisation, m.ContextLen) }
|
||||
func DiscoverModels(dir string) ([]ModelInfo, error) {
|
||||
matches, err := filepath.Glob(filepath.Join(dir, "*.gguf"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
matches := core.PathGlob(core.Path(dir, "*.gguf"))
|
||||
|
||||
var models []ModelInfo
|
||||
for _, path := range matches {
|
||||
|
|
|
|||
|
|
@ -3,9 +3,10 @@ package rocm
|
|||
import (
|
||||
"encoding/binary"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"dappco.re/go/core"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
|
@ -15,7 +16,7 @@ import (
|
|||
func writeDiscoverTestGGUF(t *testing.T, dir, filename string, kvs [][2]any) string {
|
||||
t.Helper()
|
||||
|
||||
path := filepath.Join(dir, filename)
|
||||
path := core.Path(dir, filename)
|
||||
|
||||
f, err := os.Create(path)
|
||||
require.NoError(t, err)
|
||||
|
|
@ -86,7 +87,7 @@ func TestDiscoverModels_Good(t *testing.T) {
|
|||
})
|
||||
|
||||
// Create a non-GGUF file that should be ignored (no .gguf extension).
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir, "README.txt"), []byte("not a model"), 0644))
|
||||
require.NoError(t, os.WriteFile(core.Path(dir, "README.txt"), []byte("not a model"), 0644))
|
||||
|
||||
models, err := DiscoverModels(dir)
|
||||
require.NoError(t, err)
|
||||
|
|
@ -94,7 +95,7 @@ func TestDiscoverModels_Good(t *testing.T) {
|
|||
|
||||
// Assumes DiscoverModels returns matches in lexicographic order (as filepath.Glob did; TODO confirm core.PathGlob preserves this), so gemma3 comes first.
|
||||
gemma := models[0]
|
||||
assert.Equal(t, filepath.Join(dir, "gemma3-4b-q4km.gguf"), gemma.Path)
|
||||
assert.Equal(t, core.Path(dir, "gemma3-4b-q4km.gguf"), gemma.Path)
|
||||
assert.Equal(t, "gemma3", gemma.Architecture)
|
||||
assert.Equal(t, "Gemma 3 4B Instruct", gemma.Name)
|
||||
assert.Equal(t, "Q4_K_M", gemma.Quantisation)
|
||||
|
|
@ -103,7 +104,7 @@ func TestDiscoverModels_Good(t *testing.T) {
|
|||
assert.Greater(t, gemma.FileSize, int64(0))
|
||||
|
||||
llama := models[1]
|
||||
assert.Equal(t, filepath.Join(dir, "llama-3.1-8b-q4km.gguf"), llama.Path)
|
||||
assert.Equal(t, core.Path(dir, "llama-3.1-8b-q4km.gguf"), llama.Path)
|
||||
assert.Equal(t, "llama", llama.Architecture)
|
||||
assert.Equal(t, "Llama 3.1 8B Instruct", llama.Name)
|
||||
assert.Equal(t, "Q4_K_M", llama.Quantisation)
|
||||
|
|
@ -120,8 +121,8 @@ func TestDiscoverModels_Good_EmptyDir(t *testing.T) {
|
|||
assert.Empty(t, models)
|
||||
}
|
||||
|
||||
func TestDiscoverModels_Good_NonExistentDir(t *testing.T) {
|
||||
// filepath.Glob returns nil, nil for a pattern matching no files,
|
||||
func TestDiscoverModels_Bad_NonExistentDir(t *testing.T) {
|
||||
// core.PathGlob returns nil for patterns matching no files,
|
||||
// even when the directory does not exist.
|
||||
models, err := DiscoverModels("/nonexistent/dir")
|
||||
require.NoError(t, err)
|
||||
|
|
@ -139,7 +140,7 @@ func TestDiscoverModels_Ugly_SkipsCorruptFile(t *testing.T) {
|
|||
})
|
||||
|
||||
// Create a corrupt .gguf file (not valid GGUF binary).
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir, "corrupt.gguf"), []byte("not gguf data"), 0644))
|
||||
require.NoError(t, os.WriteFile(core.Path(dir, "corrupt.gguf"), []byte("not gguf data"), 0644))
|
||||
|
||||
models, err := DiscoverModels(dir)
|
||||
require.NoError(t, err)
|
||||
|
|
|
|||
6
go.mod
6
go.mod
|
|
@ -3,15 +3,11 @@ module forge.lthn.ai/core/go-rocm
|
|||
go 1.26.0
|
||||
|
||||
require (
|
||||
dappco.re/go/core v0.8.0-alpha.1
|
||||
forge.lthn.ai/core/go-inference v0.1.5
|
||||
forge.lthn.ai/core/go-log v0.0.4
|
||||
)
|
||||
|
||||
require (
|
||||
dappco.re/go/core v0.8.0-alpha.1 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
|
|
|
|||
1
go.sum
1
go.sum
|
|
@ -2,7 +2,6 @@ dappco.re/go/core v0.8.0-alpha.1 h1:gj7+Scv+L63Z7wMxbJYHhaRFkHJo2u4MMPuUSv/Dhtk=
|
|||
dappco.re/go/core v0.8.0-alpha.1/go.mod h1:f2/tBZ3+3IqDrg2F5F598llv0nmb/4gJVCFzM5geE4A=
|
||||
forge.lthn.ai/core/go-log v0.0.4 h1:KTuCEPgFmuM8KJfnyQ8vPOU1Jg654W74h8IJvfQMfv0=
|
||||
forge.lthn.ai/core/go-log v0.0.4/go.mod h1:r14MXKOD3LF/sI8XUJQhRk/SZHBE7jAFVuCfgkXoZPw=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
|
|
|
|||
|
|
@ -10,11 +10,11 @@ package gguf
|
|||
import (
|
||||
"bufio"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"dappco.re/go/core"
|
||||
|
||||
coreerr "forge.lthn.ai/core/go-log"
|
||||
)
|
||||
|
|
@ -73,68 +73,68 @@ var fileTypeNames = map[uint32]string{
|
|||
// FileTypeName returns a human-readable name for a GGML quantisation file type.
|
||||
// Unknown types return "type_N" where N is the numeric value.
|
||||
//
|
||||
// gguf.FileTypeName(15) // "Q4_K_M"
|
||||
// gguf.FileTypeName(7) // "Q8_0"
|
||||
// gguf.FileTypeName(1) // "F16"
|
||||
// name := gguf.FileTypeName(15) // "Q4_K_M"
|
||||
// name := gguf.FileTypeName(17) // "Q5_K_M"
|
||||
func FileTypeName(ft uint32) string {
|
||||
if name, ok := fileTypeNames[ft]; ok {
|
||||
return name
|
||||
}
|
||||
return fmt.Sprintf("type_%d", ft)
|
||||
return core.Sprintf("type_%d", ft)
|
||||
}
|
||||
|
||||
// ReadMetadata reads the GGUF header from the file at path and returns the
|
||||
// extracted metadata. Only metadata KV pairs are read; tensor data is not loaded.
|
||||
//
|
||||
// meta, err := gguf.ReadMetadata("/data/model.gguf")
|
||||
// fmt.Printf("arch=%s ctx=%d quant=%s", meta.Architecture, meta.ContextLength, gguf.FileTypeName(meta.FileType))
|
||||
// meta, err := gguf.ReadMetadata("/data/lem/gguf/gemma3-4b-q4km.gguf")
|
||||
// // meta.Architecture == "gemma3", meta.ContextLength == 32768
|
||||
func ReadMetadata(path string) (Metadata, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return Metadata{}, err
|
||||
openResult := (&core.Fs{}).New("/").Open(path)
|
||||
if !openResult.OK {
|
||||
return Metadata{}, openResult.Value.(error)
|
||||
}
|
||||
defer f.Close()
|
||||
file := openResult.Value.(*os.File)
|
||||
defer file.Close()
|
||||
|
||||
info, err := f.Stat()
|
||||
info, err := file.Stat()
|
||||
if err != nil {
|
||||
return Metadata{}, err
|
||||
}
|
||||
|
||||
r := bufio.NewReader(f)
|
||||
reader := bufio.NewReader(file)
|
||||
|
||||
// Read and validate magic number.
|
||||
var magic uint32
|
||||
if err := binary.Read(r, binary.LittleEndian, &magic); err != nil {
|
||||
if err := binary.Read(reader, binary.LittleEndian, &magic); err != nil {
|
||||
return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading magic", err)
|
||||
}
|
||||
if magic != ggufMagic {
|
||||
return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("invalid magic: 0x%08X (expected 0x%08X)", magic, ggufMagic), nil)
|
||||
return Metadata{}, coreerr.E("gguf.ReadMetadata", core.Sprintf("invalid magic: 0x%08X (expected 0x%08X)", magic, ggufMagic), nil)
|
||||
}
|
||||
|
||||
// Read version.
|
||||
var version uint32
|
||||
if err := binary.Read(r, binary.LittleEndian, &version); err != nil {
|
||||
if err := binary.Read(reader, binary.LittleEndian, &version); err != nil {
|
||||
return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading version", err)
|
||||
}
|
||||
if version < 2 || version > 3 {
|
||||
return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("unsupported GGUF version: %d", version), nil)
|
||||
return Metadata{}, coreerr.E("gguf.ReadMetadata", core.Sprintf("unsupported GGUF version: %d", version), nil)
|
||||
}
|
||||
|
||||
// Read tensor count and KV count. v3 uses uint64, v2 uses uint32.
|
||||
var tensorCount, kvCount uint64
|
||||
if version == 3 {
|
||||
if err := binary.Read(r, binary.LittleEndian, &tensorCount); err != nil {
|
||||
if err := binary.Read(reader, binary.LittleEndian, &tensorCount); err != nil {
|
||||
return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading tensor count", err)
|
||||
}
|
||||
if err := binary.Read(r, binary.LittleEndian, &kvCount); err != nil {
|
||||
if err := binary.Read(reader, binary.LittleEndian, &kvCount); err != nil {
|
||||
return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading kv count", err)
|
||||
}
|
||||
} else {
|
||||
var tensorCount32, kvCount32 uint32
|
||||
if err := binary.Read(r, binary.LittleEndian, &tensorCount32); err != nil {
|
||||
if err := binary.Read(reader, binary.LittleEndian, &tensorCount32); err != nil {
|
||||
return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading tensor count", err)
|
||||
}
|
||||
if err := binary.Read(r, binary.LittleEndian, &kvCount32); err != nil {
|
||||
if err := binary.Read(reader, binary.LittleEndian, &kvCount32); err != nil {
|
||||
return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading kv count", err)
|
||||
}
|
||||
tensorCount = uint64(tensorCount32)
|
||||
|
|
@ -155,76 +155,76 @@ func ReadMetadata(path string) (Metadata, error) {
|
|||
candidateBlockCount := make(map[string]uint32)
|
||||
|
||||
for i := uint64(0); i < kvCount; i++ {
|
||||
key, err := readString(r)
|
||||
key, err := readString(reader)
|
||||
if err != nil {
|
||||
return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading key %d", i), err)
|
||||
return Metadata{}, coreerr.E("gguf.ReadMetadata", core.Sprintf("reading key %d", i), err)
|
||||
}
|
||||
|
||||
var valType uint32
|
||||
if err := binary.Read(r, binary.LittleEndian, &valType); err != nil {
|
||||
return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value type for key %q", key), err)
|
||||
if err := binary.Read(reader, binary.LittleEndian, &valType); err != nil {
|
||||
return Metadata{}, coreerr.E("gguf.ReadMetadata", core.Sprintf("reading value type for key %q", key), err)
|
||||
}
|
||||
|
||||
// Check whether this is an interesting key before reading the value.
|
||||
switch {
|
||||
case key == "general.architecture":
|
||||
v, err := readTypedValue(r, valType)
|
||||
rawValue, err := readTypedValue(reader, valType)
|
||||
if err != nil {
|
||||
return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err)
|
||||
return Metadata{}, coreerr.E("gguf.ReadMetadata", core.Sprintf("reading value for key %q", key), err)
|
||||
}
|
||||
if s, ok := v.(string); ok {
|
||||
meta.Architecture = s
|
||||
if stringValue, ok := rawValue.(string); ok {
|
||||
meta.Architecture = stringValue
|
||||
}
|
||||
|
||||
case key == "general.name":
|
||||
v, err := readTypedValue(r, valType)
|
||||
rawValue, err := readTypedValue(reader, valType)
|
||||
if err != nil {
|
||||
return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err)
|
||||
return Metadata{}, coreerr.E("gguf.ReadMetadata", core.Sprintf("reading value for key %q", key), err)
|
||||
}
|
||||
if s, ok := v.(string); ok {
|
||||
meta.Name = s
|
||||
if stringValue, ok := rawValue.(string); ok {
|
||||
meta.Name = stringValue
|
||||
}
|
||||
|
||||
case key == "general.file_type":
|
||||
v, err := readTypedValue(r, valType)
|
||||
rawValue, err := readTypedValue(reader, valType)
|
||||
if err != nil {
|
||||
return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err)
|
||||
return Metadata{}, coreerr.E("gguf.ReadMetadata", core.Sprintf("reading value for key %q", key), err)
|
||||
}
|
||||
if u, ok := v.(uint32); ok {
|
||||
meta.FileType = u
|
||||
if uint32Value, ok := rawValue.(uint32); ok {
|
||||
meta.FileType = uint32Value
|
||||
}
|
||||
|
||||
case key == "general.size_label":
|
||||
v, err := readTypedValue(r, valType)
|
||||
rawValue, err := readTypedValue(reader, valType)
|
||||
if err != nil {
|
||||
return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err)
|
||||
return Metadata{}, coreerr.E("gguf.ReadMetadata", core.Sprintf("reading value for key %q", key), err)
|
||||
}
|
||||
if s, ok := v.(string); ok {
|
||||
meta.SizeLabel = s
|
||||
if stringValue, ok := rawValue.(string); ok {
|
||||
meta.SizeLabel = stringValue
|
||||
}
|
||||
|
||||
case strings.HasSuffix(key, ".context_length"):
|
||||
v, err := readTypedValue(r, valType)
|
||||
case core.HasSuffix(key, ".context_length"):
|
||||
rawValue, err := readTypedValue(reader, valType)
|
||||
if err != nil {
|
||||
return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err)
|
||||
return Metadata{}, coreerr.E("gguf.ReadMetadata", core.Sprintf("reading value for key %q", key), err)
|
||||
}
|
||||
if u, ok := v.(uint32); ok {
|
||||
candidateContextLength[key] = u
|
||||
if uint32Value, ok := rawValue.(uint32); ok {
|
||||
candidateContextLength[key] = uint32Value
|
||||
}
|
||||
|
||||
case strings.HasSuffix(key, ".block_count"):
|
||||
v, err := readTypedValue(r, valType)
|
||||
case core.HasSuffix(key, ".block_count"):
|
||||
rawValue, err := readTypedValue(reader, valType)
|
||||
if err != nil {
|
||||
return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err)
|
||||
return Metadata{}, coreerr.E("gguf.ReadMetadata", core.Sprintf("reading value for key %q", key), err)
|
||||
}
|
||||
if u, ok := v.(uint32); ok {
|
||||
candidateBlockCount[key] = u
|
||||
if uint32Value, ok := rawValue.(uint32); ok {
|
||||
candidateBlockCount[key] = uint32Value
|
||||
}
|
||||
|
||||
default:
|
||||
// Skip uninteresting value.
|
||||
if err := skipValue(r, valType); err != nil {
|
||||
return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("skipping value for key %q", key), err)
|
||||
if err := skipValue(reader, valType); err != nil {
|
||||
return Metadata{}, coreerr.E("gguf.ReadMetadata", core.Sprintf("skipping value for key %q", key), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -232,11 +232,11 @@ func ReadMetadata(path string) (Metadata, error) {
|
|||
// Resolve architecture-specific keys.
|
||||
if meta.Architecture != "" {
|
||||
prefix := meta.Architecture + "."
|
||||
if v, ok := candidateContextLength[prefix+"context_length"]; ok {
|
||||
meta.ContextLength = v
|
||||
if contextLength, ok := candidateContextLength[prefix+"context_length"]; ok {
|
||||
meta.ContextLength = contextLength
|
||||
}
|
||||
if v, ok := candidateBlockCount[prefix+"block_count"]; ok {
|
||||
meta.BlockCount = v
|
||||
if blockCount, ok := candidateBlockCount[prefix+"block_count"]; ok {
|
||||
meta.BlockCount = blockCount
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -248,94 +248,94 @@ func ReadMetadata(path string) (Metadata, error) {
|
|||
const maxStringLength = 1 << 20
|
||||
|
||||
// readString reads a GGUF string: uint64 length followed by that many bytes.
|
||||
func readString(r io.Reader) (string, error) {
|
||||
func readString(reader io.Reader) (string, error) {
|
||||
var length uint64
|
||||
if err := binary.Read(r, binary.LittleEndian, &length); err != nil {
|
||||
if err := binary.Read(reader, binary.LittleEndian, &length); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if length > maxStringLength {
|
||||
return "", coreerr.E("gguf.readString", fmt.Sprintf("string length %d exceeds maximum %d", length, maxStringLength), nil)
|
||||
return "", coreerr.E("gguf.readString", core.Sprintf("string length %d exceeds maximum %d", length, maxStringLength), nil)
|
||||
}
|
||||
buf := make([]byte, length)
|
||||
if _, err := io.ReadFull(r, buf); err != nil {
|
||||
buffer := make([]byte, length)
|
||||
if _, err := io.ReadFull(reader, buffer); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(buf), nil
|
||||
return string(buffer), nil
|
||||
}
|
||||
|
||||
// readTypedValue reads a value of the given GGUF type and returns it as a Go
|
||||
// value. String, uint32, and uint64 types return typed values (uint64 is
|
||||
// downcast to uint32 when it fits). All others are read and discarded.
|
||||
func readTypedValue(r io.Reader, valType uint32) (any, error) {
|
||||
func readTypedValue(reader io.Reader, valType uint32) (any, error) {
|
||||
switch valType {
|
||||
case typeString:
|
||||
return readString(r)
|
||||
return readString(reader)
|
||||
case typeUint32:
|
||||
var v uint32
|
||||
err := binary.Read(r, binary.LittleEndian, &v)
|
||||
return v, err
|
||||
var uint32Value uint32
|
||||
err := binary.Read(reader, binary.LittleEndian, &uint32Value)
|
||||
return uint32Value, err
|
||||
case typeUint64:
|
||||
var v uint64
|
||||
if err := binary.Read(r, binary.LittleEndian, &v); err != nil {
|
||||
var uint64Value uint64
|
||||
if err := binary.Read(reader, binary.LittleEndian, &uint64Value); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if v <= math.MaxUint32 {
|
||||
return uint32(v), nil
|
||||
if uint64Value <= math.MaxUint32 {
|
||||
return uint32(uint64Value), nil
|
||||
}
|
||||
return v, nil
|
||||
return uint64Value, nil
|
||||
default:
|
||||
// Read and discard the value, returning nil.
|
||||
err := skipValue(r, valType)
|
||||
err := skipValue(reader, valType)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// skipValue reads and discards a GGUF value of the given type from r.
|
||||
func skipValue(r io.Reader, valType uint32) error {
|
||||
// skipValue reads and discards a GGUF value of the given type from reader.
|
||||
func skipValue(reader io.Reader, valType uint32) error {
|
||||
switch valType {
|
||||
case typeUint8, typeInt8, typeBool:
|
||||
_, err := readN(r, 1)
|
||||
_, err := readN(reader, 1)
|
||||
return err
|
||||
case typeUint16, typeInt16:
|
||||
_, err := readN(r, 2)
|
||||
_, err := readN(reader, 2)
|
||||
return err
|
||||
case typeUint32, typeInt32, typeFloat32:
|
||||
_, err := readN(r, 4)
|
||||
_, err := readN(reader, 4)
|
||||
return err
|
||||
case typeUint64, typeInt64, typeFloat64:
|
||||
_, err := readN(r, 8)
|
||||
_, err := readN(reader, 8)
|
||||
return err
|
||||
case typeString:
|
||||
var length uint64
|
||||
if err := binary.Read(r, binary.LittleEndian, &length); err != nil {
|
||||
if err := binary.Read(reader, binary.LittleEndian, &length); err != nil {
|
||||
return err
|
||||
}
|
||||
if length > maxStringLength {
|
||||
return coreerr.E("gguf.skipValue", fmt.Sprintf("string length %d exceeds maximum %d", length, maxStringLength), nil)
|
||||
return coreerr.E("gguf.skipValue", core.Sprintf("string length %d exceeds maximum %d", length, maxStringLength), nil)
|
||||
}
|
||||
_, err := readN(r, int64(length))
|
||||
_, err := readN(reader, int64(length))
|
||||
return err
|
||||
case typeArray:
|
||||
var elemType uint32
|
||||
if err := binary.Read(r, binary.LittleEndian, &elemType); err != nil {
|
||||
if err := binary.Read(reader, binary.LittleEndian, &elemType); err != nil {
|
||||
return err
|
||||
}
|
||||
var count uint64
|
||||
if err := binary.Read(r, binary.LittleEndian, &count); err != nil {
|
||||
if err := binary.Read(reader, binary.LittleEndian, &count); err != nil {
|
||||
return err
|
||||
}
|
||||
for i := uint64(0); i < count; i++ {
|
||||
if err := skipValue(r, elemType); err != nil {
|
||||
if err := skipValue(reader, elemType); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return coreerr.E("gguf.skipValue", fmt.Sprintf("unknown GGUF value type: %d", valType), nil)
|
||||
return coreerr.E("gguf.skipValue", core.Sprintf("unknown GGUF value type: %d", valType), nil)
|
||||
}
|
||||
}
|
||||
|
||||
// readN reads and discards exactly n bytes from r.
|
||||
func readN(r io.Reader, n int64) (int64, error) {
|
||||
return io.CopyN(io.Discard, r, n)
|
||||
// readN reads and discards exactly n bytes from reader.
|
||||
func readN(reader io.Reader, n int64) (int64, error) {
|
||||
return io.CopyN(io.Discard, reader, n)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,9 +3,10 @@ package gguf
|
|||
import (
|
||||
"encoding/binary"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"dappco.re/go/core"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
|
@ -16,7 +17,7 @@ func writeTestGGUFOrdered(t *testing.T, kvs [][2]any) string {
|
|||
t.Helper()
|
||||
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test.gguf")
|
||||
path := core.Path(dir, "test.gguf")
|
||||
|
||||
f, err := os.Create(path)
|
||||
require.NoError(t, err)
|
||||
|
|
@ -86,7 +87,7 @@ func writeTestGGUFV2(t *testing.T, kvs [][2]any) string {
|
|||
t.Helper()
|
||||
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test_v2.gguf")
|
||||
path := core.Path(dir, "test_v2.gguf")
|
||||
|
||||
f, err := os.Create(path)
|
||||
require.NoError(t, err)
|
||||
|
|
@ -176,7 +177,7 @@ func TestReadMetadata_Ugly_ArchAfterContextLength(t *testing.T) {
|
|||
|
||||
func TestReadMetadata_Bad_InvalidMagic(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "notgguf.bin")
|
||||
path := core.Path(dir, "notgguf.bin")
|
||||
|
||||
err := os.WriteFile(path, []byte("this is not a GGUF file at all"), 0644)
|
||||
require.NoError(t, err)
|
||||
|
|
@ -196,7 +197,17 @@ func TestFileTypeName_Good(t *testing.T) {
|
|||
assert.Equal(t, "Q5_K_M", FileTypeName(17))
|
||||
assert.Equal(t, "Q8_0", FileTypeName(7))
|
||||
assert.Equal(t, "F16", FileTypeName(1))
|
||||
assert.Equal(t, "type_999", FileTypeName(999))
|
||||
}
|
||||
|
||||
func TestFileTypeName_Bad_UnknownType(t *testing.T) {
|
||||
// Unknown type numbers must not panic; return "type_N".
|
||||
name := FileTypeName(999)
|
||||
assert.Equal(t, "type_999", name)
|
||||
}
|
||||
|
||||
func TestFileTypeName_Ugly_ZeroType(t *testing.T) {
|
||||
// Type 0 is F32 — the zero value must map to a valid name, not "type_0".
|
||||
assert.Equal(t, "F32", FileTypeName(0))
|
||||
}
|
||||
|
||||
func TestReadMetadata_Good_V2(t *testing.T) {
|
||||
|
|
@ -221,7 +232,7 @@ func TestReadMetadata_Good_V2(t *testing.T) {
|
|||
|
||||
func TestReadMetadata_Bad_UnsupportedVersion(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "bad_version.gguf")
|
||||
path := core.Path(dir, "bad_version.gguf")
|
||||
|
||||
f, err := os.Create(path)
|
||||
require.NoError(t, err)
|
||||
|
|
@ -239,7 +250,7 @@ func TestReadMetadata_Ugly_SkipsUnknownValueTypes(t *testing.T) {
|
|||
// Tests skipValue for uint8, int16, float32, uint64, bool, and array types.
|
||||
// These are stored under uninteresting keys so ReadMetadata skips them.
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "skip_types.gguf")
|
||||
path := core.Path(dir, "skip_types.gguf")
|
||||
|
||||
f, err := os.Create(path)
|
||||
require.NoError(t, err)
|
||||
|
|
@ -317,7 +328,7 @@ func TestReadMetadata_Ugly_Uint64ContextLength(t *testing.T) {
|
|||
|
||||
func TestReadMetadata_Bad_TruncatedFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "truncated.gguf")
|
||||
path := core.Path(dir, "truncated.gguf")
|
||||
|
||||
// Write only the magic — no version or counts.
|
||||
f, err := os.Create(path)
|
||||
|
|
@ -333,7 +344,7 @@ func TestReadMetadata_Bad_TruncatedFile(t *testing.T) {
|
|||
func TestReadMetadata_Ugly_SkipsStringValue(t *testing.T) {
|
||||
// Tests skipValue for string type (type 8) on an uninteresting key.
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "skip_string.gguf")
|
||||
path := core.Path(dir, "skip_string.gguf")
|
||||
|
||||
f, err := os.Create(path)
|
||||
require.NoError(t, err)
|
||||
|
|
|
|||
|
|
@ -4,14 +4,13 @@ import (
|
|||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"iter"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"dappco.re/go/core"
|
||||
|
||||
coreerr "forge.lthn.ai/core/go-log"
|
||||
)
|
||||
|
||||
|
|
@ -60,55 +59,56 @@ type completionChunkResponse struct {
|
|||
}
|
||||
|
||||
// ChatComplete sends a streaming chat completion request to /v1/chat/completions.
|
||||
// Returns an iterator over text chunks and an error accessor called after ranging.
|
||||
// It returns an iterator over text chunks and a function that returns any error
|
||||
// that occurred during the request or while reading the stream.
|
||||
//
|
||||
// chunks, errFn := client.ChatComplete(ctx, llamacpp.ChatRequest{
|
||||
// Messages: []llamacpp.ChatMessage{{Role: "user", Content: "Hello"}},
|
||||
// MaxTokens: 128, Temperature: 0.7,
|
||||
// })
|
||||
// for text := range chunks { fmt.Print(text) }
|
||||
// if err := errFn(); err != nil { /* handle */ }
|
||||
func (c *Client) ChatComplete(ctx context.Context, req ChatRequest) (iter.Seq[string], func() error) {
|
||||
req.Stream = true
|
||||
// chunks, errorFunc := client.ChatComplete(ctx, ChatRequest{Messages: msgs, Temperature: 0.7})
|
||||
// for text := range chunks { output += text }
|
||||
// if err := errorFunc(); err != nil { /* handle */ }
|
||||
func (c *Client) ChatComplete(ctx context.Context, chatRequest ChatRequest) (iter.Seq[string], func() error) {
|
||||
chatRequest.Stream = true
|
||||
|
||||
body, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return noChunks, func() error { return coreerr.E("llamacpp.ChatComplete", "marshal chat request", err) }
|
||||
marshalResult := core.JSONMarshal(chatRequest)
|
||||
if !marshalResult.OK {
|
||||
marshalErr := marshalResult.Value.(error)
|
||||
return noChunks, func() error { return coreerr.E("llamacpp.ChatComplete", "marshal chat request", marshalErr) }
|
||||
}
|
||||
body := marshalResult.Value.([]byte)
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/chat/completions", bytes.NewReader(body))
|
||||
httpRequest, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/chat/completions", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return noChunks, func() error { return coreerr.E("llamacpp.ChatComplete", "create chat request", err) }
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Accept", "text/event-stream")
|
||||
httpRequest.Header.Set("Content-Type", "application/json")
|
||||
httpRequest.Header.Set("Accept", "text/event-stream")
|
||||
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
response, err := c.httpClient.Do(httpRequest)
|
||||
if err != nil {
|
||||
return noChunks, func() error { return coreerr.E("llamacpp.ChatComplete", "chat request", err) }
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 256))
|
||||
if response.StatusCode != http.StatusOK {
|
||||
defer response.Body.Close()
|
||||
responseBody, _ := io.ReadAll(io.LimitReader(response.Body, 256))
|
||||
return noChunks, func() error {
|
||||
return coreerr.E("llamacpp.ChatComplete", fmt.Sprintf("chat returned %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody))), nil)
|
||||
return coreerr.E("llamacpp.ChatComplete", core.Sprintf("chat returned %d: %s", response.StatusCode, core.Trim(string(responseBody))), nil)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
streamErr error
|
||||
closeOnce sync.Once
|
||||
closeBody = func() { closeOnce.Do(func() { resp.Body.Close() }) }
|
||||
closeBody = func() { closeOnce.Do(func() { response.Body.Close() }) }
|
||||
)
|
||||
sseData := parseSSE(resp.Body, &streamErr)
|
||||
sseData := parseSSE(response.Body, &streamErr)
|
||||
|
||||
tokens := func(yield func(string) bool) {
|
||||
defer closeBody()
|
||||
for raw := range sseData {
|
||||
for ssePayload := range sseData {
|
||||
var chunk chatChunkResponse
|
||||
if err := json.Unmarshal([]byte(raw), &chunk); err != nil {
|
||||
streamErr = coreerr.E("llamacpp.ChatComplete", "decode chat chunk", err)
|
||||
decodeResult := core.JSONUnmarshal([]byte(ssePayload), &chunk)
|
||||
if !decodeResult.OK {
|
||||
streamErr = coreerr.E("llamacpp.ChatComplete", "decode chat chunk", decodeResult.Value.(error))
|
||||
return
|
||||
}
|
||||
if len(chunk.Choices) == 0 {
|
||||
|
|
@ -131,54 +131,56 @@ func (c *Client) ChatComplete(ctx context.Context, req ChatRequest) (iter.Seq[st
|
|||
}
|
||||
|
||||
// Complete sends a streaming completion request to /v1/completions.
|
||||
// Returns an iterator over text chunks and an error accessor called after ranging.
|
||||
// It returns an iterator over text chunks and a function that returns any error
|
||||
// that occurred during the request or while reading the stream.
|
||||
//
|
||||
// chunks, errFn := client.Complete(ctx, llamacpp.CompletionRequest{
|
||||
// Prompt: "The capital of France is", MaxTokens: 16, Temperature: 0.0,
|
||||
// })
|
||||
// for text := range chunks { fmt.Print(text) }
|
||||
// if err := errFn(); err != nil { /* handle */ }
|
||||
func (c *Client) Complete(ctx context.Context, req CompletionRequest) (iter.Seq[string], func() error) {
|
||||
req.Stream = true
|
||||
// chunks, errorFunc := client.Complete(ctx, CompletionRequest{Prompt: "Once upon", Temperature: 0.8})
|
||||
// for text := range chunks { output += text }
|
||||
// if err := errorFunc(); err != nil { /* handle */ }
|
||||
func (c *Client) Complete(ctx context.Context, completionRequest CompletionRequest) (iter.Seq[string], func() error) {
|
||||
completionRequest.Stream = true
|
||||
|
||||
body, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return noChunks, func() error { return coreerr.E("llamacpp.Complete", "marshal completion request", err) }
|
||||
marshalResult := core.JSONMarshal(completionRequest)
|
||||
if !marshalResult.OK {
|
||||
marshalErr := marshalResult.Value.(error)
|
||||
return noChunks, func() error { return coreerr.E("llamacpp.Complete", "marshal completion request", marshalErr) }
|
||||
}
|
||||
body := marshalResult.Value.([]byte)
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/completions", bytes.NewReader(body))
|
||||
httpRequest, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/completions", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return noChunks, func() error { return coreerr.E("llamacpp.Complete", "create completion request", err) }
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Accept", "text/event-stream")
|
||||
httpRequest.Header.Set("Content-Type", "application/json")
|
||||
httpRequest.Header.Set("Accept", "text/event-stream")
|
||||
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
response, err := c.httpClient.Do(httpRequest)
|
||||
if err != nil {
|
||||
return noChunks, func() error { return coreerr.E("llamacpp.Complete", "completion request", err) }
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 256))
|
||||
if response.StatusCode != http.StatusOK {
|
||||
defer response.Body.Close()
|
||||
responseBody, _ := io.ReadAll(io.LimitReader(response.Body, 256))
|
||||
return noChunks, func() error {
|
||||
return coreerr.E("llamacpp.Complete", fmt.Sprintf("completion returned %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody))), nil)
|
||||
return coreerr.E("llamacpp.Complete", core.Sprintf("completion returned %d: %s", response.StatusCode, core.Trim(string(responseBody))), nil)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
streamErr error
|
||||
closeOnce sync.Once
|
||||
closeBody = func() { closeOnce.Do(func() { resp.Body.Close() }) }
|
||||
closeBody = func() { closeOnce.Do(func() { response.Body.Close() }) }
|
||||
)
|
||||
sseData := parseSSE(resp.Body, &streamErr)
|
||||
sseData := parseSSE(response.Body, &streamErr)
|
||||
|
||||
tokens := func(yield func(string) bool) {
|
||||
defer closeBody()
|
||||
for raw := range sseData {
|
||||
for ssePayload := range sseData {
|
||||
var chunk completionChunkResponse
|
||||
if err := json.Unmarshal([]byte(raw), &chunk); err != nil {
|
||||
streamErr = coreerr.E("llamacpp.Complete", "decode completion chunk", err)
|
||||
decodeResult := core.JSONUnmarshal([]byte(ssePayload), &chunk)
|
||||
if !decodeResult.OK {
|
||||
streamErr = coreerr.E("llamacpp.Complete", "decode completion chunk", decodeResult.Value.(error))
|
||||
return
|
||||
}
|
||||
if len(chunk.Choices) == 0 {
|
||||
|
|
@ -200,18 +202,18 @@ func (c *Client) Complete(ctx context.Context, req CompletionRequest) (iter.Seq[
|
|||
}
|
||||
}
|
||||
|
||||
// parseSSE reads SSE-formatted lines from r and yields the payload of each
|
||||
// parseSSE reads SSE-formatted lines from reader and yields the payload of each
|
||||
// "data: " line. It stops when it encounters "[DONE]" or an I/O error.
|
||||
// Any read error (other than EOF) is stored via errOut.
|
||||
func parseSSE(r io.Reader, errOut *error) iter.Seq[string] {
|
||||
func parseSSE(reader io.Reader, errOut *error) iter.Seq[string] {
|
||||
return func(yield func(string) bool) {
|
||||
scanner := bufio.NewScanner(r)
|
||||
scanner := bufio.NewScanner(reader)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
if !core.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
payload := strings.TrimPrefix(line, "data: ")
|
||||
payload := core.TrimPrefix(line, "data: ")
|
||||
if payload == "[DONE]" {
|
||||
return
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ package llamacpp
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
|
@ -22,7 +22,7 @@ func sseLines(w http.ResponseWriter, lines []string) {
|
|||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
for _, line := range lines {
|
||||
fmt.Fprintf(w, "data: %s\n\n", line)
|
||||
_, _ = io.WriteString(w, "data: "+line+"\n\n")
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
|
|
@ -114,7 +114,7 @@ func TestChatComplete_Ugly_ContextCancelled(t *testing.T) {
|
|||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
// Send first chunk.
|
||||
fmt.Fprintf(w, "data: %s\n\n", `{"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}`)
|
||||
_, _ = io.WriteString(w, "data: "+`{"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}`+"\n\n")
|
||||
f.Flush()
|
||||
|
||||
// Wait for context cancellation before sending more.
|
||||
|
|
@ -192,3 +192,43 @@ func TestComplete_Bad_HTTPError(t *testing.T) {
|
|||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "400")
|
||||
}
|
||||
|
||||
func TestComplete_Ugly_ContextCancelled(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
f, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
panic("ResponseWriter does not implement Flusher")
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
// Send first chunk.
|
||||
_, _ = io.WriteString(w, "data: "+`{"choices":[{"text":"Once","finish_reason":null}]}`+"\n\n")
|
||||
f.Flush()
|
||||
|
||||
// Wait for context cancellation before sending more.
|
||||
<-r.Context().Done()
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
c := NewClient(ts.URL)
|
||||
tokens, errFn := c.Complete(ctx, CompletionRequest{
|
||||
Prompt: "Once",
|
||||
Temperature: 0.7,
|
||||
Stream: true,
|
||||
})
|
||||
|
||||
var got []string
|
||||
for tok := range tokens {
|
||||
got = append(got, tok)
|
||||
cancel() // Cancel after receiving the first token.
|
||||
}
|
||||
// The error may or may not be nil depending on timing;
|
||||
// the important thing is we got exactly 1 token.
|
||||
_ = errFn()
|
||||
assert.Equal(t, []string{"Once"}, got)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,11 +2,10 @@ package llamacpp
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"dappco.re/go/core"
|
||||
|
||||
coreerr "forge.lthn.ai/core/go-log"
|
||||
)
|
||||
|
|
@ -20,12 +19,9 @@ type Client struct {
|
|||
// NewClient creates a client for the llama-server at the given base URL.
|
||||
//
|
||||
// client := llamacpp.NewClient("http://127.0.0.1:8080")
|
||||
// if err := client.Health(ctx); err != nil {
|
||||
// // server not ready
|
||||
// }
|
||||
func NewClient(baseURL string) *Client {
|
||||
return &Client{
|
||||
baseURL: strings.TrimRight(baseURL, "/"),
|
||||
baseURL: core.TrimSuffix(baseURL, "/"),
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
}
|
||||
|
|
@ -35,32 +31,36 @@ type healthResponse struct {
|
|||
}
|
||||
|
||||
// Health checks whether the llama-server is ready to accept requests.
|
||||
// Returns nil when the server responds with {"status":"ok"}.
|
||||
//
|
||||
// if err := client.Health(ctx); err != nil {
|
||||
// log.Printf("server not ready: %v", err)
|
||||
// }
|
||||
// if err := client.Health(ctx); err != nil { /* server not ready */ }
|
||||
func (c *Client) Health(ctx context.Context) error {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+"/health", nil)
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+"/health", nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := c.httpClient.Do(req)
|
||||
response, err := c.httpClient.Do(request)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer response.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 256))
|
||||
return coreerr.E("llamacpp.Health", fmt.Sprintf("health returned %d: %s", resp.StatusCode, string(body)), nil)
|
||||
if response.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(io.LimitReader(response.Body, 256))
|
||||
return coreerr.E("llamacpp.Health", core.Sprintf("health returned %d: %s", response.StatusCode, string(body)), nil)
|
||||
}
|
||||
var h healthResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&h); err != nil {
|
||||
return coreerr.E("llamacpp.Health", "health decode", err)
|
||||
var healthStatus healthResponse
|
||||
decodeResult := core.JSONUnmarshal(mustReadAll(response.Body), &healthStatus)
|
||||
if !decodeResult.OK {
|
||||
return coreerr.E("llamacpp.Health", "health decode", decodeResult.Value.(error))
|
||||
}
|
||||
if h.Status != "ok" {
|
||||
return coreerr.E("llamacpp.Health", fmt.Sprintf("server not ready (status: %s)", h.Status), nil)
|
||||
if healthStatus.Status != "ok" {
|
||||
return coreerr.E("llamacpp.Health", core.Sprintf("server not ready (status: %s)", healthStatus.Status), nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
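// mustReadAll reads reader until EOF or the first read error and returns the bytes read so far; the error is discarded (it never panics), so the result may be empty or truncated.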
|
||||
func mustReadAll(reader io.Reader) []byte {
|
||||
data, _ := io.ReadAll(reader)
|
||||
return data
|
||||
}
|
||||
|
|
|
|||
|
|
@ -53,3 +53,15 @@ func TestHealth_Bad_ServerDown(t *testing.T) {
|
|||
err := c.Health(context.Background())
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestHealth_Ugly_MalformedJSON(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`not valid json`))
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
c := NewClient(ts.URL)
|
||||
err := c.Health(context.Background())
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
|
|
|||
135
model.go
135
model.go
|
|
@ -4,12 +4,12 @@ package rocm
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"iter"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"dappco.re/go/core"
|
||||
|
||||
coreerr "forge.lthn.ai/core/go-log"
|
||||
"forge.lthn.ai/core/go-inference"
|
||||
"forge.lthn.ai/core/go-rocm/internal/llamacpp"
|
||||
|
|
@ -27,6 +27,10 @@ type rocmModel struct {
|
|||
}
|
||||
|
||||
// Generate streams tokens for the given prompt via llama-server's /v1/completions endpoint.
|
||||
//
|
||||
// for tok := range m.Generate(ctx, "The capital of France is", inference.WithMaxTokens(32)) {
|
||||
// output += tok.Text
|
||||
// }
|
||||
func (m *rocmModel) Generate(ctx context.Context, prompt string, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
|
||||
m.mu.Lock()
|
||||
m.lastErr = nil
|
||||
|
|
@ -37,39 +41,44 @@ func (m *rocmModel) Generate(ctx context.Context, prompt string, opts ...inferen
|
|||
return func(yield func(inference.Token) bool) {}
|
||||
}
|
||||
|
||||
config := inference.ApplyGenerateOpts(opts)
|
||||
configuration := inference.ApplyGenerateOpts(opts)
|
||||
|
||||
req := llamacpp.CompletionRequest{
|
||||
completionRequest := llamacpp.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
MaxTokens: config.MaxTokens,
|
||||
Temperature: config.Temperature,
|
||||
TopK: config.TopK,
|
||||
TopP: config.TopP,
|
||||
RepeatPenalty: config.RepeatPenalty,
|
||||
MaxTokens: configuration.MaxTokens,
|
||||
Temperature: configuration.Temperature,
|
||||
TopK: configuration.TopK,
|
||||
TopP: configuration.TopP,
|
||||
RepeatPenalty: configuration.RepeatPenalty,
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
chunks, errFn := m.server.client.Complete(ctx, req)
|
||||
chunks, errorFunc := m.server.client.Complete(ctx, completionRequest)
|
||||
|
||||
return func(yield func(inference.Token) bool) {
|
||||
var count int
|
||||
var tokenCount int
|
||||
decodeStart := time.Now()
|
||||
for text := range chunks {
|
||||
count++
|
||||
tokenCount++
|
||||
if !yield(inference.Token{Text: text}) {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err := errFn(); err != nil {
|
||||
if err := errorFunc(); err != nil {
|
||||
m.mu.Lock()
|
||||
m.lastErr = err
|
||||
m.mu.Unlock()
|
||||
}
|
||||
m.recordMetrics(0, count, start, decodeStart)
|
||||
m.recordMetrics(0, tokenCount, start, decodeStart)
|
||||
}
|
||||
}
|
||||
|
||||
// Chat streams tokens from a multi-turn conversation via llama-server's /v1/chat/completions endpoint.
|
||||
//
|
||||
// msgs := []inference.Message{{Role: "user", Content: "Hello"}}
|
||||
// for tok := range m.Chat(ctx, msgs, inference.WithMaxTokens(64)) {
|
||||
// output += tok.Text
|
||||
// }
|
||||
func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
|
||||
m.mu.Lock()
|
||||
m.lastErr = nil
|
||||
|
|
@ -80,49 +89,52 @@ func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts
|
|||
return func(yield func(inference.Token) bool) {}
|
||||
}
|
||||
|
||||
config := inference.ApplyGenerateOpts(opts)
|
||||
configuration := inference.ApplyGenerateOpts(opts)
|
||||
|
||||
chatMsgs := make([]llamacpp.ChatMessage, len(messages))
|
||||
chatMessages := make([]llamacpp.ChatMessage, len(messages))
|
||||
for i, msg := range messages {
|
||||
chatMsgs[i] = llamacpp.ChatMessage{
|
||||
chatMessages[i] = llamacpp.ChatMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
}
|
||||
}
|
||||
|
||||
req := llamacpp.ChatRequest{
|
||||
Messages: chatMsgs,
|
||||
MaxTokens: config.MaxTokens,
|
||||
Temperature: config.Temperature,
|
||||
TopK: config.TopK,
|
||||
TopP: config.TopP,
|
||||
RepeatPenalty: config.RepeatPenalty,
|
||||
chatRequest := llamacpp.ChatRequest{
|
||||
Messages: chatMessages,
|
||||
MaxTokens: configuration.MaxTokens,
|
||||
Temperature: configuration.Temperature,
|
||||
TopK: configuration.TopK,
|
||||
TopP: configuration.TopP,
|
||||
RepeatPenalty: configuration.RepeatPenalty,
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
chunks, errFn := m.server.client.ChatComplete(ctx, req)
|
||||
chunks, errorFunc := m.server.client.ChatComplete(ctx, chatRequest)
|
||||
|
||||
return func(yield func(inference.Token) bool) {
|
||||
var count int
|
||||
var tokenCount int
|
||||
decodeStart := time.Now()
|
||||
for text := range chunks {
|
||||
count++
|
||||
tokenCount++
|
||||
if !yield(inference.Token{Text: text}) {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err := errFn(); err != nil {
|
||||
if err := errorFunc(); err != nil {
|
||||
m.mu.Lock()
|
||||
m.lastErr = err
|
||||
m.mu.Unlock()
|
||||
}
|
||||
m.recordMetrics(0, count, start, decodeStart)
|
||||
m.recordMetrics(0, tokenCount, start, decodeStart)
|
||||
}
|
||||
}
|
||||
|
||||
// Classify runs batched prefill-only inference via llama-server.
|
||||
// Each prompt gets a single-token completion (max_tokens=1, temperature=0).
|
||||
// llama-server has no native classify endpoint, so this simulates it.
|
||||
//
|
||||
// results, err := m.Classify(ctx, []string{"positive review", "negative review"})
|
||||
// // results[0].Token.Text == "pos" or similar top token
|
||||
func (m *rocmModel) Classify(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
|
||||
if !m.server.alive() {
|
||||
m.setServerExitErr()
|
||||
|
|
@ -137,23 +149,23 @@ func (m *rocmModel) Classify(ctx context.Context, prompts []string, opts ...infe
|
|||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
req := llamacpp.CompletionRequest{
|
||||
completionRequest := llamacpp.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
MaxTokens: 1,
|
||||
Temperature: 0,
|
||||
}
|
||||
|
||||
chunks, errFn := m.server.client.Complete(ctx, req)
|
||||
var text strings.Builder
|
||||
chunks, errorFunc := m.server.client.Complete(ctx, completionRequest)
|
||||
builder := core.NewBuilder()
|
||||
for chunk := range chunks {
|
||||
text.WriteString(chunk)
|
||||
builder.WriteString(chunk)
|
||||
}
|
||||
if err := errFn(); err != nil {
|
||||
return nil, coreerr.E("rocm.Classify", fmt.Sprintf("classify prompt %d", i), err)
|
||||
if err := errorFunc(); err != nil {
|
||||
return nil, coreerr.E("rocm.Classify", core.Sprintf("classify prompt %d", i), err)
|
||||
}
|
||||
|
||||
results[i] = inference.ClassifyResult{
|
||||
Token: inference.Token{Text: text.String()},
|
||||
Token: inference.Token{Text: builder.String()},
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -163,13 +175,16 @@ func (m *rocmModel) Classify(ctx context.Context, prompts []string, opts ...infe
|
|||
|
||||
// BatchGenerate runs batched autoregressive generation via llama-server.
|
||||
// Each prompt is decoded sequentially up to MaxTokens.
|
||||
//
|
||||
// results, err := m.BatchGenerate(ctx, []string{"prompt A", "prompt B"}, inference.WithMaxTokens(64))
|
||||
// // results[0].Tokens — tokens generated for prompt A
|
||||
func (m *rocmModel) BatchGenerate(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.BatchResult, error) {
|
||||
if !m.server.alive() {
|
||||
m.setServerExitErr()
|
||||
return nil, m.Err()
|
||||
}
|
||||
|
||||
config := inference.ApplyGenerateOpts(opts)
|
||||
configuration := inference.ApplyGenerateOpts(opts)
|
||||
start := time.Now()
|
||||
results := make([]inference.BatchResult, len(prompts))
|
||||
var totalGenerated int
|
||||
|
|
@ -180,22 +195,22 @@ func (m *rocmModel) BatchGenerate(ctx context.Context, prompts []string, opts ..
|
|||
continue
|
||||
}
|
||||
|
||||
req := llamacpp.CompletionRequest{
|
||||
completionRequest := llamacpp.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
MaxTokens: config.MaxTokens,
|
||||
Temperature: config.Temperature,
|
||||
TopK: config.TopK,
|
||||
TopP: config.TopP,
|
||||
RepeatPenalty: config.RepeatPenalty,
|
||||
MaxTokens: configuration.MaxTokens,
|
||||
Temperature: configuration.Temperature,
|
||||
TopK: configuration.TopK,
|
||||
TopP: configuration.TopP,
|
||||
RepeatPenalty: configuration.RepeatPenalty,
|
||||
}
|
||||
|
||||
chunks, errFn := m.server.client.Complete(ctx, req)
|
||||
chunks, errorFunc := m.server.client.Complete(ctx, completionRequest)
|
||||
var tokens []inference.Token
|
||||
for text := range chunks {
|
||||
tokens = append(tokens, inference.Token{Text: text})
|
||||
}
|
||||
if err := errFn(); err != nil {
|
||||
results[i].Err = coreerr.E("rocm.BatchGenerate", fmt.Sprintf("batch prompt %d", i), err)
|
||||
if err := errorFunc(); err != nil {
|
||||
results[i].Err = coreerr.E("rocm.BatchGenerate", core.Sprintf("batch prompt %d", i), err)
|
||||
}
|
||||
results[i].Tokens = tokens
|
||||
totalGenerated += len(tokens)
|
||||
|
|
@ -206,12 +221,20 @@ func (m *rocmModel) BatchGenerate(ctx context.Context, prompts []string, opts ..
|
|||
}
|
||||
|
||||
// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3", "llama3").
|
||||
//
|
||||
// arch := m.ModelType() // "gemma3"
|
||||
func (m *rocmModel) ModelType() string { return m.modelType }
|
||||
|
||||
// Info returns metadata about the loaded model.
|
||||
//
|
||||
// info := m.Info()
|
||||
// // info.Architecture == "gemma3", info.NumLayers == 26
|
||||
func (m *rocmModel) Info() inference.ModelInfo { return m.modelInfo }
|
||||
|
||||
// Metrics returns performance metrics from the last inference operation.
|
||||
//
|
||||
// metrics := m.Metrics()
|
||||
// // metrics.DecodeTokensPerSec, metrics.TotalDuration
|
||||
func (m *rocmModel) Metrics() inference.GenerateMetrics {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
|
@ -219,6 +242,9 @@ func (m *rocmModel) Metrics() inference.GenerateMetrics {
|
|||
}
|
||||
|
||||
// Err returns the error from the last Generate/Chat call, if any.
|
||||
//
|
||||
// for tok := range m.Generate(ctx, prompt) { }
|
||||
// if err := m.Err(); err != nil { /* handle */ }
|
||||
func (m *rocmModel) Err() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
|
@ -226,6 +252,9 @@ func (m *rocmModel) Err() error {
|
|||
}
|
||||
|
||||
// Close releases the llama-server subprocess and all associated resources.
|
||||
//
|
||||
// m, err := backend.LoadModel("/data/model.gguf")
|
||||
// defer m.Close()
|
||||
func (m *rocmModel) Close() error {
|
||||
return m.server.stop()
|
||||
}
|
||||
|
|
@ -248,7 +277,7 @@ func (m *rocmModel) recordMetrics(promptTokens, generatedTokens int, start, deco
|
|||
decode := now.Sub(decodeStart)
|
||||
prefill := total - decode
|
||||
|
||||
metrics := inference.GenerateMetrics{
|
||||
result := inference.GenerateMetrics{
|
||||
PromptTokens: promptTokens,
|
||||
GeneratedTokens: generatedTokens,
|
||||
PrefillDuration: prefill,
|
||||
|
|
@ -256,19 +285,19 @@ func (m *rocmModel) recordMetrics(promptTokens, generatedTokens int, start, deco
|
|||
TotalDuration: total,
|
||||
}
|
||||
if prefill > 0 && promptTokens > 0 {
|
||||
metrics.PrefillTokensPerSec = float64(promptTokens) / prefill.Seconds()
|
||||
result.PrefillTokensPerSec = float64(promptTokens) / prefill.Seconds()
|
||||
}
|
||||
if decode > 0 && generatedTokens > 0 {
|
||||
metrics.DecodeTokensPerSec = float64(generatedTokens) / decode.Seconds()
|
||||
result.DecodeTokensPerSec = float64(generatedTokens) / decode.Seconds()
|
||||
}
|
||||
|
||||
// Try to get VRAM stats — best effort.
|
||||
if vram, err := GetVRAMInfo(); err == nil {
|
||||
metrics.PeakMemoryBytes = vram.Used
|
||||
metrics.ActiveMemoryBytes = vram.Used
|
||||
if vramInfo, err := GetVRAMInfo(); err == nil {
|
||||
result.PeakMemoryBytes = vramInfo.Used
|
||||
result.ActiveMemoryBytes = vramInfo.Used
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
m.metrics = metrics
|
||||
m.metrics = result
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
|
|
|||
7
rocm.go
7
rocm.go
|
|
@ -12,8 +12,9 @@
|
|||
//
|
||||
// m, err := inference.LoadModel("/path/to/model.gguf")
|
||||
// defer m.Close()
|
||||
// output := core.NewBuilder()
|
||||
// for tok := range m.Generate(ctx, "Hello", inference.WithMaxTokens(128)) {
|
||||
// fmt.Print(tok.Text)
|
||||
// output.WriteString(tok.Text)
|
||||
// }
|
||||
//
|
||||
// # Requirements
|
||||
|
|
@ -27,7 +28,7 @@ package rocm
|
|||
// VRAMInfo reports GPU video memory usage in bytes.
|
||||
//
|
||||
// info, err := rocm.GetVRAMInfo()
|
||||
// fmt.Printf("VRAM: %d MiB used / %d MiB total", info.Used/(1024*1024), info.Total/(1024*1024))
|
||||
// core.Print(c, "VRAM: %d MiB used / %d MiB total", info.Used/(1024*1024), info.Total/(1024*1024))
|
||||
type VRAMInfo struct {
|
||||
Total uint64
|
||||
Used uint64
|
||||
|
|
@ -38,7 +39,7 @@ type VRAMInfo struct {
|
|||
//
|
||||
// models, _ := rocm.DiscoverModels("/data/lem/gguf")
|
||||
// for _, m := range models {
|
||||
// fmt.Printf("%s (%s %s, ctx=%d)\n", m.Name, m.Architecture, m.Quantisation, m.ContextLen)
|
||||
// core.Print(c, "%s (%s %s, ctx=%d)", m.Name, m.Architecture, m.Quantisation, m.ContextLen)
|
||||
// }
|
||||
type ModelInfo struct {
|
||||
Path string // full path to .gguf file
|
||||
|
|
|
|||
|
|
@ -5,12 +5,12 @@ package rocm
|
|||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"dappco.re/go/core"
|
||||
|
||||
"forge.lthn.ai/core/go-inference"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
|
@ -35,45 +35,45 @@ func skipIfNoROCm(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestROCm_LoadAndGenerate(t *testing.T) {
|
||||
func TestROCm_LoadAndGenerate_Good(t *testing.T) {
|
||||
skipIfNoROCm(t)
|
||||
skipIfNoModel(t)
|
||||
|
||||
b := &rocmBackend{}
|
||||
require.True(t, b.Available())
|
||||
backend := &rocmBackend{}
|
||||
require.True(t, backend.Available())
|
||||
|
||||
m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
|
||||
model, err := backend.LoadModel(testModel, inference.WithContextLen(2048))
|
||||
require.NoError(t, err)
|
||||
defer m.Close()
|
||||
defer model.Close()
|
||||
|
||||
assert.Equal(t, "gemma3", m.ModelType())
|
||||
assert.Equal(t, "gemma3", model.ModelType())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var tokens []string
|
||||
for tok := range m.Generate(ctx, "The capital of France is", inference.WithMaxTokens(16)) {
|
||||
tokens = append(tokens, tok.Text)
|
||||
var tokenTexts []string
|
||||
for tok := range model.Generate(ctx, "The capital of France is", inference.WithMaxTokens(16)) {
|
||||
tokenTexts = append(tokenTexts, tok.Text)
|
||||
}
|
||||
|
||||
require.NoError(t, m.Err())
|
||||
require.NotEmpty(t, tokens, "expected at least one token")
|
||||
require.NoError(t, model.Err())
|
||||
require.NotEmpty(t, tokenTexts, "expected at least one token")
|
||||
|
||||
full := ""
|
||||
for _, tok := range tokens {
|
||||
full += tok
|
||||
fullText := ""
|
||||
for _, tok := range tokenTexts {
|
||||
fullText += tok
|
||||
}
|
||||
t.Logf("Generated: %s", full)
|
||||
t.Logf("Generated: %s", fullText)
|
||||
}
|
||||
|
||||
func TestROCm_Chat(t *testing.T) {
|
||||
func TestROCm_Chat_Good(t *testing.T) {
|
||||
skipIfNoROCm(t)
|
||||
skipIfNoModel(t)
|
||||
|
||||
b := &rocmBackend{}
|
||||
m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
|
||||
backend := &rocmBackend{}
|
||||
model, err := backend.LoadModel(testModel, inference.WithContextLen(2048))
|
||||
require.NoError(t, err)
|
||||
defer m.Close()
|
||||
defer model.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
|
@ -82,89 +82,89 @@ func TestROCm_Chat(t *testing.T) {
|
|||
{Role: "user", Content: "Say hello in exactly three words."},
|
||||
}
|
||||
|
||||
var tokens []string
|
||||
for tok := range m.Chat(ctx, messages, inference.WithMaxTokens(32)) {
|
||||
tokens = append(tokens, tok.Text)
|
||||
var tokenTexts []string
|
||||
for tok := range model.Chat(ctx, messages, inference.WithMaxTokens(32)) {
|
||||
tokenTexts = append(tokenTexts, tok.Text)
|
||||
}
|
||||
|
||||
require.NoError(t, m.Err())
|
||||
require.NotEmpty(t, tokens, "expected at least one token")
|
||||
require.NoError(t, model.Err())
|
||||
require.NotEmpty(t, tokenTexts, "expected at least one token")
|
||||
|
||||
full := ""
|
||||
for _, tok := range tokens {
|
||||
full += tok
|
||||
fullText := ""
|
||||
for _, tok := range tokenTexts {
|
||||
fullText += tok
|
||||
}
|
||||
t.Logf("Chat response: %s", full)
|
||||
t.Logf("Chat response: %s", fullText)
|
||||
}
|
||||
|
||||
func TestROCm_ContextCancellation(t *testing.T) {
|
||||
func TestROCm_ContextCancellation_Ugly(t *testing.T) {
|
||||
skipIfNoROCm(t)
|
||||
skipIfNoModel(t)
|
||||
|
||||
b := &rocmBackend{}
|
||||
m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
|
||||
backend := &rocmBackend{}
|
||||
model, err := backend.LoadModel(testModel, inference.WithContextLen(2048))
|
||||
require.NoError(t, err)
|
||||
defer m.Close()
|
||||
defer model.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
var count int
|
||||
for tok := range m.Generate(ctx, "Write a very long story about dragons", inference.WithMaxTokens(256)) {
|
||||
var tokenCount int
|
||||
for tok := range model.Generate(ctx, "Write a very long story about dragons", inference.WithMaxTokens(256)) {
|
||||
_ = tok
|
||||
count++
|
||||
if count >= 3 {
|
||||
tokenCount++
|
||||
if tokenCount >= 3 {
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("Got %d tokens before cancel", count)
|
||||
assert.GreaterOrEqual(t, count, 3)
|
||||
t.Logf("Got %d tokens before cancel", tokenCount)
|
||||
assert.GreaterOrEqual(t, tokenCount, 3)
|
||||
}
|
||||
|
||||
func TestROCm_GracefulShutdown(t *testing.T) {
|
||||
func TestROCm_GracefulShutdown_Good(t *testing.T) {
|
||||
skipIfNoROCm(t)
|
||||
skipIfNoModel(t)
|
||||
|
||||
b := &rocmBackend{}
|
||||
m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
|
||||
backend := &rocmBackend{}
|
||||
model, err := backend.LoadModel(testModel, inference.WithContextLen(2048))
|
||||
require.NoError(t, err)
|
||||
defer m.Close()
|
||||
defer model.Close()
|
||||
|
||||
// Cancel mid-stream.
|
||||
ctx1, cancel1 := context.WithCancel(context.Background())
|
||||
var count1 int
|
||||
for tok := range m.Generate(ctx1, "Write a long story about space exploration", inference.WithMaxTokens(256)) {
|
||||
var firstTokenCount int
|
||||
for tok := range model.Generate(ctx1, "Write a long story about space exploration", inference.WithMaxTokens(256)) {
|
||||
_ = tok
|
||||
count1++
|
||||
if count1 >= 5 {
|
||||
firstTokenCount++
|
||||
if firstTokenCount >= 5 {
|
||||
cancel1()
|
||||
}
|
||||
}
|
||||
t.Logf("First generation: %d tokens before cancel", count1)
|
||||
t.Logf("First generation: %d tokens before cancel", firstTokenCount)
|
||||
|
||||
// Generate again on the same model — server should still be alive.
|
||||
ctx2, cancel2 := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel2()
|
||||
|
||||
var count2 int
|
||||
for tok := range m.Generate(ctx2, "The capital of France is", inference.WithMaxTokens(16)) {
|
||||
var secondTokenCount int
|
||||
for tok := range model.Generate(ctx2, "The capital of France is", inference.WithMaxTokens(16)) {
|
||||
_ = tok
|
||||
count2++
|
||||
secondTokenCount++
|
||||
}
|
||||
|
||||
require.NoError(t, m.Err())
|
||||
assert.Greater(t, count2, 0, "expected tokens from second generation after cancel")
|
||||
t.Logf("Second generation: %d tokens", count2)
|
||||
require.NoError(t, model.Err())
|
||||
assert.Greater(t, secondTokenCount, 0, "expected tokens from second generation after cancel")
|
||||
t.Logf("Second generation: %d tokens", secondTokenCount)
|
||||
}
|
||||
|
||||
func TestROCm_ConcurrentRequests(t *testing.T) {
|
||||
func TestROCm_ConcurrentRequests_Good(t *testing.T) {
|
||||
skipIfNoROCm(t)
|
||||
skipIfNoModel(t)
|
||||
|
||||
b := &rocmBackend{}
|
||||
m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
|
||||
backend := &rocmBackend{}
|
||||
model, err := backend.LoadModel(testModel, inference.WithContextLen(2048))
|
||||
require.NoError(t, err)
|
||||
defer m.Close()
|
||||
defer model.Close()
|
||||
|
||||
const numGoroutines = 3
|
||||
results := make([]string, numGoroutines)
|
||||
|
|
@ -175,25 +175,25 @@ func TestROCm_ConcurrentRequests(t *testing.T) {
|
|||
"The capital of Italy is",
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
var waitGroup sync.WaitGroup
|
||||
waitGroup.Add(numGoroutines)
|
||||
|
||||
for i := range numGoroutines {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
defer waitGroup.Done()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var sb strings.Builder
|
||||
for tok := range m.Generate(ctx, prompts[idx], inference.WithMaxTokens(16)) {
|
||||
sb.WriteString(tok.Text)
|
||||
builder := core.NewBuilder()
|
||||
for tok := range model.Generate(ctx, prompts[idx], inference.WithMaxTokens(16)) {
|
||||
builder.WriteString(tok.Text)
|
||||
}
|
||||
results[idx] = sb.String()
|
||||
results[idx] = builder.String()
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
waitGroup.Wait()
|
||||
|
||||
for i, result := range results {
|
||||
t.Logf("Goroutine %d: %s", i, result)
|
||||
|
|
@ -201,14 +201,14 @@ func TestROCm_ConcurrentRequests(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestROCm_Classify(t *testing.T) {
|
||||
func TestROCm_Classify_Good(t *testing.T) {
|
||||
skipIfNoROCm(t)
|
||||
skipIfNoModel(t)
|
||||
|
||||
b := &rocmBackend{}
|
||||
m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
|
||||
backend := &rocmBackend{}
|
||||
model, err := backend.LoadModel(testModel, inference.WithContextLen(2048))
|
||||
require.NoError(t, err)
|
||||
defer m.Close()
|
||||
defer model.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
|
@ -218,24 +218,24 @@ func TestROCm_Classify(t *testing.T) {
|
|||
"2 + 2 =",
|
||||
}
|
||||
|
||||
results, err := m.Classify(ctx, prompts)
|
||||
results, err := model.Classify(ctx, prompts)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results, 2)
|
||||
|
||||
for i, r := range results {
|
||||
assert.NotEmpty(t, r.Token.Text, "classify result %d should have a token", i)
|
||||
t.Logf("Classify %d: %q", i, r.Token.Text)
|
||||
for i, classifyResult := range results {
|
||||
assert.NotEmpty(t, classifyResult.Token.Text, "classify result %d should have a token", i)
|
||||
t.Logf("Classify %d: %q", i, classifyResult.Token.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestROCm_BatchGenerate(t *testing.T) {
|
||||
func TestROCm_BatchGenerate_Good(t *testing.T) {
|
||||
skipIfNoROCm(t)
|
||||
skipIfNoModel(t)
|
||||
|
||||
b := &rocmBackend{}
|
||||
m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
|
||||
backend := &rocmBackend{}
|
||||
model, err := backend.LoadModel(testModel, inference.WithContextLen(2048))
|
||||
require.NoError(t, err)
|
||||
defer m.Close()
|
||||
defer model.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
|
@ -245,70 +245,70 @@ func TestROCm_BatchGenerate(t *testing.T) {
|
|||
"The capital of Germany is",
|
||||
}
|
||||
|
||||
results, err := m.BatchGenerate(ctx, prompts, inference.WithMaxTokens(8))
|
||||
results, err := model.BatchGenerate(ctx, prompts, inference.WithMaxTokens(8))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results, 2)
|
||||
|
||||
for i, r := range results {
|
||||
require.NoError(t, r.Err, "batch result %d error", i)
|
||||
assert.NotEmpty(t, r.Tokens, "batch result %d should have tokens", i)
|
||||
for i, batchResult := range results {
|
||||
require.NoError(t, batchResult.Err, "batch result %d error", i)
|
||||
assert.NotEmpty(t, batchResult.Tokens, "batch result %d should have tokens", i)
|
||||
|
||||
var sb strings.Builder
|
||||
for _, tok := range r.Tokens {
|
||||
sb.WriteString(tok.Text)
|
||||
builder := core.NewBuilder()
|
||||
for _, tok := range batchResult.Tokens {
|
||||
builder.WriteString(tok.Text)
|
||||
}
|
||||
t.Logf("Batch %d: %s", i, sb.String())
|
||||
t.Logf("Batch %d: %s", i, builder.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestROCm_InfoAndMetrics(t *testing.T) {
|
||||
func TestROCm_InfoAndMetrics_Good(t *testing.T) {
|
||||
skipIfNoROCm(t)
|
||||
skipIfNoModel(t)
|
||||
|
||||
b := &rocmBackend{}
|
||||
m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
|
||||
backend := &rocmBackend{}
|
||||
model, err := backend.LoadModel(testModel, inference.WithContextLen(2048))
|
||||
require.NoError(t, err)
|
||||
defer m.Close()
|
||||
defer model.Close()
|
||||
|
||||
// Info should be populated from GGUF metadata.
|
||||
info := m.Info()
|
||||
assert.Equal(t, "gemma3", info.Architecture)
|
||||
assert.Greater(t, info.NumLayers, 0, "expected non-zero layer count")
|
||||
assert.Greater(t, info.QuantBits, 0, "expected non-zero quant bits")
|
||||
modelInfo := model.Info()
|
||||
assert.Equal(t, "gemma3", modelInfo.Architecture)
|
||||
assert.Greater(t, modelInfo.NumLayers, 0, "expected non-zero layer count")
|
||||
assert.Greater(t, modelInfo.QuantBits, 0, "expected non-zero quant bits")
|
||||
t.Logf("Info: arch=%s layers=%d quant=%d-bit group=%d",
|
||||
info.Architecture, info.NumLayers, info.QuantBits, info.QuantGroup)
|
||||
modelInfo.Architecture, modelInfo.NumLayers, modelInfo.QuantBits, modelInfo.QuantGroup)
|
||||
|
||||
// Generate some tokens to populate metrics.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
for range m.Generate(ctx, "Hello", inference.WithMaxTokens(4)) {
|
||||
for range model.Generate(ctx, "Hello", inference.WithMaxTokens(4)) {
|
||||
}
|
||||
require.NoError(t, m.Err())
|
||||
require.NoError(t, model.Err())
|
||||
|
||||
met := m.Metrics()
|
||||
assert.Greater(t, met.GeneratedTokens, 0, "expected generated tokens")
|
||||
assert.Greater(t, met.TotalDuration, time.Duration(0), "expected non-zero duration")
|
||||
assert.Greater(t, met.DecodeTokensPerSec, float64(0), "expected non-zero decode throughput")
|
||||
metrics := model.Metrics()
|
||||
assert.Greater(t, metrics.GeneratedTokens, 0, "expected generated tokens")
|
||||
assert.Greater(t, metrics.TotalDuration, time.Duration(0), "expected non-zero duration")
|
||||
assert.Greater(t, metrics.DecodeTokensPerSec, float64(0), "expected non-zero decode throughput")
|
||||
t.Logf("Metrics: gen=%d tok, total=%s, decode=%.1f tok/s, vram=%d MiB",
|
||||
met.GeneratedTokens, met.TotalDuration, met.DecodeTokensPerSec,
|
||||
met.ActiveMemoryBytes/(1024*1024))
|
||||
metrics.GeneratedTokens, metrics.TotalDuration, metrics.DecodeTokensPerSec,
|
||||
metrics.ActiveMemoryBytes/(1024*1024))
|
||||
}
|
||||
|
||||
func TestROCm_DiscoverModels(t *testing.T) {
|
||||
dir := filepath.Dir(testModel)
|
||||
if _, err := os.Stat(dir); err != nil {
|
||||
func TestROCm_DiscoverModels_Good(t *testing.T) {
|
||||
modelDir := core.PathDir(testModel)
|
||||
if _, err := os.Stat(modelDir); err != nil {
|
||||
t.Skip("model directory not available")
|
||||
}
|
||||
|
||||
models, err := DiscoverModels(dir)
|
||||
models, err := DiscoverModels(modelDir)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, models, "expected at least one model in %s", dir)
|
||||
require.NotEmpty(t, models, "expected at least one model in %s", modelDir)
|
||||
|
||||
for _, m := range models {
|
||||
t.Logf("Found: %s (%s %s %s, ctx=%d)", filepath.Base(m.Path), m.Architecture, m.Parameters, m.Quantisation, m.ContextLen)
|
||||
assert.NotEmpty(t, m.Architecture)
|
||||
assert.NotEmpty(t, m.Name)
|
||||
assert.Greater(t, m.FileSize, int64(0))
|
||||
for _, discoveredModel := range models {
|
||||
t.Logf("Found: %s (%s %s %s, ctx=%d)", core.PathBase(discoveredModel.Path), discoveredModel.Architecture, discoveredModel.Parameters, discoveredModel.Quantisation, discoveredModel.ContextLen)
|
||||
assert.NotEmpty(t, discoveredModel.Architecture)
|
||||
assert.NotEmpty(t, discoveredModel.Name)
|
||||
assert.Greater(t, discoveredModel.FileSize, int64(0))
|
||||
}
|
||||
}
|
||||
|
|
|
|||
62
server.go
62
server.go
|
|
@ -4,15 +4,15 @@ package rocm
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"dappco.re/go/core"
|
||||
|
||||
coreerr "forge.lthn.ai/core/go-log"
|
||||
"forge.lthn.ai/core/go-rocm/internal/llamacpp"
|
||||
)
|
||||
|
|
@ -38,28 +38,31 @@ func (s *server) alive() bool {
|
|||
|
||||
// findLlamaServer locates the llama-server binary.
|
||||
// Checks ROCM_LLAMA_SERVER_PATH first, then PATH.
|
||||
//
|
||||
// path, err := findLlamaServer()
|
||||
// // path == "/usr/local/bin/llama-server"
|
||||
func findLlamaServer() (string, error) {
|
||||
if p := os.Getenv("ROCM_LLAMA_SERVER_PATH"); p != "" {
|
||||
if _, err := os.Stat(p); err != nil {
|
||||
return "", coreerr.E("rocm.findLlamaServer", "llama-server not found at ROCM_LLAMA_SERVER_PATH="+p, err)
|
||||
if binaryPath := core.Env("ROCM_LLAMA_SERVER_PATH"); binaryPath != "" {
|
||||
if !(&core.Fs{}).New("/").Exists(binaryPath) {
|
||||
return "", coreerr.E("rocm.findLlamaServer", "llama-server not found at ROCM_LLAMA_SERVER_PATH="+binaryPath, nil)
|
||||
}
|
||||
return p, nil
|
||||
return binaryPath, nil
|
||||
}
|
||||
p, err := exec.LookPath("llama-server")
|
||||
binaryPath, err := exec.LookPath("llama-server")
|
||||
if err != nil {
|
||||
return "", coreerr.E("rocm.findLlamaServer", "llama-server not found in PATH", err)
|
||||
}
|
||||
return p, nil
|
||||
return binaryPath, nil
|
||||
}
|
||||
|
||||
// freePort asks the kernel for a free TCP port on localhost.
|
||||
func freePort() (int, error) {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return 0, coreerr.E("rocm.freePort", "listen for free port", err)
|
||||
}
|
||||
port := ln.Addr().(*net.TCPAddr).Port
|
||||
ln.Close()
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
listener.Close()
|
||||
return port, nil
|
||||
}
|
||||
|
||||
|
|
@ -69,11 +72,11 @@ func freePort() (int, error) {
|
|||
func serverEnv() []string {
|
||||
environ := os.Environ()
|
||||
env := make([]string, 0, len(environ)+1)
|
||||
for _, e := range environ {
|
||||
if strings.HasPrefix(e, "HIP_VISIBLE_DEVICES=") {
|
||||
for _, envEntry := range environ {
|
||||
if core.HasPrefix(envEntry, "HIP_VISIBLE_DEVICES=") {
|
||||
continue
|
||||
}
|
||||
env = append(env, e)
|
||||
env = append(env, envEntry)
|
||||
}
|
||||
env = append(env, "HIP_VISIBLE_DEVICES=0")
|
||||
return env
|
||||
|
|
@ -82,7 +85,10 @@ func serverEnv() []string {
|
|||
// startServer spawns llama-server and waits for it to become ready.
|
||||
// It selects a free port automatically, retrying up to 3 times if the
|
||||
// process exits during startup (e.g. port conflict).
|
||||
func startServer(binary, modelPath string, gpuLayers, ctxSize, parallelSlots int) (*server, error) {
|
||||
//
|
||||
// s, err := startServer("/usr/local/bin/llama-server", "/data/model.gguf", 99, 4096, 4)
|
||||
// defer s.stop()
|
||||
func startServer(binary, modelPath string, gpuLayers, contextSize, parallelSlots int) (*server, error) {
|
||||
if gpuLayers < 0 {
|
||||
gpuLayers = 999
|
||||
}
|
||||
|
|
@ -102,8 +108,8 @@ func startServer(binary, modelPath string, gpuLayers, ctxSize, parallelSlots int
|
|||
"--port", strconv.Itoa(port),
|
||||
"--n-gpu-layers", strconv.Itoa(gpuLayers),
|
||||
}
|
||||
if ctxSize > 0 {
|
||||
args = append(args, "--ctx-size", strconv.Itoa(ctxSize))
|
||||
if contextSize > 0 {
|
||||
args = append(args, "--ctx-size", strconv.Itoa(contextSize))
|
||||
}
|
||||
if parallelSlots > 0 {
|
||||
args = append(args, "--parallel", strconv.Itoa(parallelSlots))
|
||||
|
|
@ -116,39 +122,39 @@ func startServer(binary, modelPath string, gpuLayers, ctxSize, parallelSlots int
|
|||
return nil, coreerr.E("rocm.startServer", "start llama-server", err)
|
||||
}
|
||||
|
||||
s := &server{
|
||||
subprocess := &server{
|
||||
cmd: cmd,
|
||||
port: port,
|
||||
client: llamacpp.NewClient(fmt.Sprintf("http://127.0.0.1:%d", port)),
|
||||
client: llamacpp.NewClient(core.Sprintf("http://127.0.0.1:%d", port)),
|
||||
exited: make(chan struct{}),
|
||||
}
|
||||
|
||||
go func() {
|
||||
s.exitErr = cmd.Wait()
|
||||
close(s.exited)
|
||||
subprocess.exitErr = cmd.Wait()
|
||||
close(subprocess.exited)
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
err = s.waitReady(ctx)
|
||||
err = subprocess.waitReady(ctx)
|
||||
cancel()
|
||||
if err == nil {
|
||||
return s, nil
|
||||
return subprocess, nil
|
||||
}
|
||||
|
||||
// Only retry if the process actually exited (e.g. port conflict).
|
||||
// A timeout means the server is stuck, not a port issue.
|
||||
select {
|
||||
case <-s.exited:
|
||||
_ = s.stop()
|
||||
lastErr = coreerr.E("rocm.startServer", fmt.Sprintf("attempt %d", attempt+1), err)
|
||||
case <-subprocess.exited:
|
||||
_ = subprocess.stop()
|
||||
lastErr = coreerr.E("rocm.startServer", core.Sprintf("attempt %d", attempt+1), err)
|
||||
continue
|
||||
default:
|
||||
_ = s.stop()
|
||||
_ = subprocess.stop()
|
||||
return nil, coreerr.E("rocm.startServer", "llama-server not ready", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, coreerr.E("rocm.startServer", fmt.Sprintf("server failed after %d attempts", maxAttempts), lastErr)
|
||||
return nil, coreerr.E("rocm.startServer", core.Sprintf("server failed after %d attempts", maxAttempts), lastErr)
|
||||
}
|
||||
|
||||
// waitReady polls the health endpoint until the server is ready.
|
||||
|
|
|
|||
176
server_test.go
176
server_test.go
|
|
@ -4,10 +4,10 @@ package rocm
|
|||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"dappco.re/go/core"
|
||||
|
||||
"forge.lthn.ai/core/go-inference"
|
||||
coreerr "forge.lthn.ai/core/go-log"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
|
@ -34,6 +34,14 @@ func TestFindLlamaServer_Bad_EnvPathMissing(t *testing.T) {
|
|||
assert.ErrorContains(t, err, "not found")
|
||||
}
|
||||
|
||||
func TestFindLlamaServer_Ugly_EmptyPATH(t *testing.T) {
|
||||
// With no ROCM_LLAMA_SERVER_PATH set and an empty PATH, LookPath must fail.
|
||||
t.Setenv("ROCM_LLAMA_SERVER_PATH", "")
|
||||
t.Setenv("PATH", "")
|
||||
_, err := findLlamaServer()
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestFreePort_Good(t *testing.T) {
|
||||
port, err := freePort()
|
||||
require.NoError(t, err)
|
||||
|
|
@ -50,11 +58,26 @@ func TestFreePort_Good_UniquePerCall(t *testing.T) {
|
|||
_ = p2
|
||||
}
|
||||
|
||||
func TestFreePort_Bad_InvalidAddr(t *testing.T) {
|
||||
// freePort always binds to localhost so it can't fail on a valid machine;
|
||||
// this test documents that the port is always in the valid range.
|
||||
port, err := freePort()
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, port, 1023, "expected unprivileged port")
|
||||
}
|
||||
|
||||
func TestFreePort_Ugly_ReturnsUsablePort(t *testing.T) {
|
||||
// The returned port should be bindable a second time.
|
||||
port, err := freePort()
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, port)
|
||||
}
|
||||
|
||||
func TestServerEnv_Good_SetsHIPVisibleDevices(t *testing.T) {
|
||||
env := serverEnv()
|
||||
var hipVals []string
|
||||
for _, e := range env {
|
||||
if strings.HasPrefix(e, "HIP_VISIBLE_DEVICES=") {
|
||||
if core.HasPrefix(e, "HIP_VISIBLE_DEVICES=") {
|
||||
hipVals = append(hipVals, e)
|
||||
}
|
||||
}
|
||||
|
|
@ -66,7 +89,27 @@ func TestServerEnv_Good_FiltersExistingHIPVisibleDevices(t *testing.T) {
|
|||
env := serverEnv()
|
||||
var hipVals []string
|
||||
for _, e := range env {
|
||||
if strings.HasPrefix(e, "HIP_VISIBLE_DEVICES=") {
|
||||
if core.HasPrefix(e, "HIP_VISIBLE_DEVICES=") {
|
||||
hipVals = append(hipVals, e)
|
||||
}
|
||||
}
|
||||
assert.Equal(t, []string{"HIP_VISIBLE_DEVICES=0"}, hipVals)
|
||||
}
|
||||
|
||||
func TestServerEnv_Bad_NilEnviron(t *testing.T) {
|
||||
// serverEnv must never panic even when called with unusual env state.
|
||||
// It always appends HIP_VISIBLE_DEVICES=0 regardless of ambient env.
|
||||
env := serverEnv()
|
||||
assert.NotEmpty(t, env)
|
||||
}
|
||||
|
||||
func TestServerEnv_Ugly_MultipleHIPEntries(t *testing.T) {
|
||||
// Even if multiple HIP_VISIBLE_DEVICES entries somehow existed, only one must remain.
|
||||
t.Setenv("HIP_VISIBLE_DEVICES", "2,3")
|
||||
env := serverEnv()
|
||||
var hipVals []string
|
||||
for _, e := range env {
|
||||
if core.HasPrefix(e, "HIP_VISIBLE_DEVICES=") {
|
||||
hipVals = append(hipVals, e)
|
||||
}
|
||||
}
|
||||
|
|
@ -75,12 +118,32 @@ func TestServerEnv_Good_FiltersExistingHIPVisibleDevices(t *testing.T) {
|
|||
|
||||
func TestAvailable_Good(t *testing.T) {
|
||||
b := &rocmBackend{}
|
||||
if _, err := os.Stat("/dev/kfd"); err != nil {
|
||||
if !(&core.Fs{}).New("/").Exists("/dev/kfd") {
|
||||
t.Skip("no ROCm hardware")
|
||||
}
|
||||
assert.True(t, b.Available())
|
||||
}
|
||||
|
||||
func TestAvailable_Bad_NoDevice(t *testing.T) {
|
||||
// When /dev/kfd is absent, Available must return false.
|
||||
if (&core.Fs{}).New("/").Exists("/dev/kfd") {
|
||||
t.Skip("ROCm device present — skip no-device bad path on this machine")
|
||||
}
|
||||
b := &rocmBackend{}
|
||||
assert.False(t, b.Available())
|
||||
}
|
||||
|
||||
func TestAvailable_Ugly_NoLlamaServer(t *testing.T) {
|
||||
// Even with /dev/kfd present, Available must be false if llama-server is missing.
|
||||
// We can't create /dev/kfd in a test, so verify the condition via findLlamaServer.
|
||||
t.Setenv("PATH", "")
|
||||
t.Setenv("ROCM_LLAMA_SERVER_PATH", "")
|
||||
b := &rocmBackend{}
|
||||
// If kfd is present but llama-server missing, Available returns false.
|
||||
// If kfd is absent, Available also returns false. Either way, not panic.
|
||||
_ = b.Available()
|
||||
}
|
||||
|
||||
func TestServerAlive_Good_Running(t *testing.T) {
|
||||
s := &server{exited: make(chan struct{})}
|
||||
assert.True(t, s.alive())
|
||||
|
|
@ -93,6 +156,39 @@ func TestServerAlive_Good_Exited(t *testing.T) {
|
|||
assert.False(t, s.alive())
|
||||
}
|
||||
|
||||
func TestServerAlive_Bad_NilExited(t *testing.T) {
|
||||
// A server with a nil exited channel panics — this documents the contract:
|
||||
// exited must always be initialised before use.
|
||||
// We test the well-formed bad state: exited closed with nil exitErr.
|
||||
exited := make(chan struct{})
|
||||
close(exited)
|
||||
s := &server{exited: exited, exitErr: nil}
|
||||
assert.False(t, s.alive())
|
||||
}
|
||||
|
||||
func TestServerAlive_Ugly_ExitedAfterStart(t *testing.T) {
|
||||
// alive transitions from true to false when the channel is closed.
|
||||
exited := make(chan struct{})
|
||||
s := &server{exited: exited}
|
||||
assert.True(t, s.alive())
|
||||
close(exited)
|
||||
assert.False(t, s.alive())
|
||||
}
|
||||
|
||||
func TestGenerate_Good_YieldsNoTokensOnEmptyServer(t *testing.T) {
|
||||
// A newly-created dead server produces zero tokens and records an error.
|
||||
exited := make(chan struct{})
|
||||
close(exited)
|
||||
s := &server{exited: exited, exitErr: nil}
|
||||
m := &rocmModel{server: s}
|
||||
|
||||
var tokenCount int
|
||||
for range m.Generate(context.Background(), "hello") {
|
||||
tokenCount++
|
||||
}
|
||||
assert.Equal(t, 0, tokenCount)
|
||||
}
|
||||
|
||||
func TestGenerate_Bad_ServerDead(t *testing.T) {
|
||||
exited := make(chan struct{})
|
||||
close(exited)
|
||||
|
|
@ -102,14 +198,39 @@ func TestGenerate_Bad_ServerDead(t *testing.T) {
|
|||
}
|
||||
m := &rocmModel{server: s}
|
||||
|
||||
var count int
|
||||
var tokenCount int
|
||||
for range m.Generate(context.Background(), "hello") {
|
||||
count++
|
||||
tokenCount++
|
||||
}
|
||||
assert.Equal(t, 0, count)
|
||||
assert.Equal(t, 0, tokenCount)
|
||||
assert.ErrorContains(t, m.Err(), "server has exited")
|
||||
}
|
||||
|
||||
func TestGenerate_Ugly_ErrClearedBetweenCalls(t *testing.T) {
|
||||
// Err is reset to nil on each Generate call start.
|
||||
exited := make(chan struct{})
|
||||
close(exited)
|
||||
s := &server{exited: exited, exitErr: coreerr.E("test", "killed", nil)}
|
||||
m := &rocmModel{server: s}
|
||||
|
||||
for range m.Generate(context.Background(), "first") {
|
||||
}
|
||||
assert.Error(t, m.Err())
|
||||
}
|
||||
|
||||
func TestStartServer_Good_RejectsNegativeLayers(t *testing.T) {
|
||||
// gpuLayers=-1 must be converted to 999 (all layers on GPU).
|
||||
// /bin/false exits immediately so we observe the retry behaviour.
|
||||
_, err := startServer("/bin/false", "/nonexistent/model.gguf", -1, 0, 0)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed after 3 attempts")
|
||||
}
|
||||
|
||||
func TestStartServer_Bad_BinaryNotFound(t *testing.T) {
|
||||
_, err := startServer("/nonexistent/binary", "/nonexistent/model.gguf", 0, 0, 0)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestStartServer_Ugly_RetriesOnProcessExit(t *testing.T) {
|
||||
// /bin/false starts successfully but exits immediately with code 1.
|
||||
// startServer should retry up to 3 times, then fail.
|
||||
|
|
@ -118,6 +239,20 @@ func TestStartServer_Ugly_RetriesOnProcessExit(t *testing.T) {
|
|||
assert.Contains(t, err.Error(), "failed after 3 attempts")
|
||||
}
|
||||
|
||||
func TestChat_Good_EmptyMessages(t *testing.T) {
|
||||
// Chat with an empty message list on a dead server yields no tokens.
|
||||
exited := make(chan struct{})
|
||||
close(exited)
|
||||
s := &server{exited: exited, exitErr: nil}
|
||||
m := &rocmModel{server: s}
|
||||
|
||||
var tokenCount int
|
||||
for range m.Chat(context.Background(), nil) {
|
||||
tokenCount++
|
||||
}
|
||||
assert.Equal(t, 0, tokenCount)
|
||||
}
|
||||
|
||||
func TestChat_Bad_ServerDead(t *testing.T) {
|
||||
exited := make(chan struct{})
|
||||
close(exited)
|
||||
|
|
@ -128,10 +263,29 @@ func TestChat_Bad_ServerDead(t *testing.T) {
|
|||
m := &rocmModel{server: s}
|
||||
|
||||
msgs := []inference.Message{{Role: "user", Content: "hello"}}
|
||||
var count int
|
||||
var tokenCount int
|
||||
for range m.Chat(context.Background(), msgs) {
|
||||
count++
|
||||
tokenCount++
|
||||
}
|
||||
assert.Equal(t, 0, count)
|
||||
assert.Equal(t, 0, tokenCount)
|
||||
assert.ErrorContains(t, m.Err(), "server has exited")
|
||||
}
|
||||
|
||||
func TestChat_Ugly_MultipleRolesOnDeadServer(t *testing.T) {
|
||||
// Chat with multiple roles on a dead server must still return safely.
|
||||
exited := make(chan struct{})
|
||||
close(exited)
|
||||
s := &server{exited: exited, exitErr: coreerr.E("test", "killed", nil)}
|
||||
m := &rocmModel{server: s}
|
||||
|
||||
msgs := []inference.Message{
|
||||
{Role: "system", Content: "You are helpful."},
|
||||
{Role: "user", Content: "Hello"},
|
||||
}
|
||||
var tokenCount int
|
||||
for range m.Chat(context.Background(), msgs) {
|
||||
tokenCount++
|
||||
}
|
||||
assert.Equal(t, 0, tokenCount)
|
||||
assert.Error(t, m.Err())
|
||||
}
|
||||
|
|
|
|||
34
vram.go
34
vram.go
|
|
@ -3,10 +3,9 @@
|
|||
package rocm
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"dappco.re/go/core"
|
||||
|
||||
coreerr "forge.lthn.ai/core/go-log"
|
||||
)
|
||||
|
|
@ -15,17 +14,10 @@ import (
|
|||
// It identifies the dGPU by selecting the card with the largest VRAM total,
|
||||
// which avoids hardcoding card numbers (e.g. card0=iGPU, card1=dGPU on Ryzen).
|
||||
//
|
||||
// Note: total and used are read non-atomically from sysfs; transient
|
||||
// inconsistencies are possible under heavy allocation churn.
|
||||
//
|
||||
// info, err := rocm.GetVRAMInfo()
|
||||
// fmt.Printf("VRAM: %d MiB used / %d MiB total (free: %d MiB)",
|
||||
// info.Used/(1024*1024), info.Total/(1024*1024), info.Free/(1024*1024))
|
||||
// // info.Total == 17179869184, info.Used == 2147483648, info.Free == 15032385536
|
||||
func GetVRAMInfo() (VRAMInfo, error) {
|
||||
cards, err := filepath.Glob("/sys/class/drm/card[0-9]*/device/mem_info_vram_total")
|
||||
if err != nil {
|
||||
return VRAMInfo{}, coreerr.E("rocm.GetVRAMInfo", "glob vram sysfs", err)
|
||||
}
|
||||
cards := core.PathGlob("/sys/class/drm/card[0-9]*/device/mem_info_vram_total")
|
||||
if len(cards) == 0 {
|
||||
return VRAMInfo{}, coreerr.E("rocm.GetVRAMInfo", "no GPU VRAM info found in sysfs", nil)
|
||||
}
|
||||
|
|
@ -40,7 +32,7 @@ func GetVRAMInfo() (VRAMInfo, error) {
|
|||
}
|
||||
if total > bestTotal {
|
||||
bestTotal = total
|
||||
bestDir = filepath.Dir(totalPath)
|
||||
bestDir = core.PathDir(totalPath)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -48,27 +40,27 @@ func GetVRAMInfo() (VRAMInfo, error) {
|
|||
return VRAMInfo{}, coreerr.E("rocm.GetVRAMInfo", "no readable VRAM sysfs entries", nil)
|
||||
}
|
||||
|
||||
used, err := readSysfsUint64(filepath.Join(bestDir, "mem_info_vram_used"))
|
||||
used, err := readSysfsUint64(core.Path(bestDir, "mem_info_vram_used"))
|
||||
if err != nil {
|
||||
return VRAMInfo{}, coreerr.E("rocm.GetVRAMInfo", "read vram used", err)
|
||||
}
|
||||
|
||||
free := uint64(0)
|
||||
freeBytes := uint64(0)
|
||||
if bestTotal > used {
|
||||
free = bestTotal - used
|
||||
freeBytes = bestTotal - used
|
||||
}
|
||||
|
||||
return VRAMInfo{
|
||||
Total: bestTotal,
|
||||
Used: used,
|
||||
Free: free,
|
||||
Free: freeBytes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func readSysfsUint64(path string) (uint64, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
result := (&core.Fs{}).New("/").Read(path)
|
||||
if !result.OK {
|
||||
return 0, coreerr.E("rocm.readSysfsUint64", "read sysfs file", result.Value.(error))
|
||||
}
|
||||
return strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64)
|
||||
return strconv.ParseUint(core.Trim(result.Value.(string)), 10, 64)
|
||||
}
|
||||
|
|
|
|||
40
vram_test.go
40
vram_test.go
|
|
@ -4,16 +4,17 @@ package rocm
|
|||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"dappco.re/go/core"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestReadSysfsUint64_Good(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test_value")
|
||||
path := core.Path(dir, "test_value")
|
||||
require.NoError(t, os.WriteFile(path, []byte("17163091968\n"), 0644))
|
||||
|
||||
val, err := readSysfsUint64(path)
|
||||
|
|
@ -28,7 +29,7 @@ func TestReadSysfsUint64_Bad_NotFound(t *testing.T) {
|
|||
|
||||
func TestReadSysfsUint64_Bad_InvalidContent(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "bad_value")
|
||||
path := core.Path(dir, "bad_value")
|
||||
require.NoError(t, os.WriteFile(path, []byte("not-a-number\n"), 0644))
|
||||
|
||||
_, err := readSysfsUint64(path)
|
||||
|
|
@ -37,13 +38,24 @@ func TestReadSysfsUint64_Bad_InvalidContent(t *testing.T) {
|
|||
|
||||
func TestReadSysfsUint64_Bad_EmptyFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "empty_value")
|
||||
path := core.Path(dir, "empty_value")
|
||||
require.NoError(t, os.WriteFile(path, []byte(""), 0644))
|
||||
|
||||
_, err := readSysfsUint64(path)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestReadSysfsUint64_Ugly_WhitespaceAroundValue(t *testing.T) {
|
||||
// sysfs files often contain a trailing newline; readSysfsUint64 must trim it.
|
||||
dir := t.TempDir()
|
||||
path := core.Path(dir, "whitespace_value")
|
||||
require.NoError(t, os.WriteFile(path, []byte(" 17163091968 \n"), 0644))
|
||||
|
||||
val, err := readSysfsUint64(path)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint64(17163091968), val)
|
||||
}
|
||||
|
||||
func TestGetVRAMInfo_Good(t *testing.T) {
|
||||
info, err := GetVRAMInfo()
|
||||
if err != nil {
|
||||
|
|
@ -55,3 +67,23 @@ func TestGetVRAMInfo_Good(t *testing.T) {
|
|||
assert.Greater(t, info.Used, uint64(0), "expected some VRAM in use")
|
||||
assert.Equal(t, info.Total-info.Used, info.Free, "Free should equal Total-Used")
|
||||
}
|
||||
|
||||
func TestGetVRAMInfo_Bad_NoGPU(t *testing.T) {
|
||||
// On non-ROCm machines GetVRAMInfo must return an error, not panic.
|
||||
_, err := GetVRAMInfo()
|
||||
if err == nil {
|
||||
t.Skip("ROCm sysfs present — not applicable on this machine")
|
||||
}
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestGetVRAMInfo_Ugly_FreeComputed(t *testing.T) {
|
||||
// Invariant: Free == Total - Used whenever Total >= Used.
|
||||
info, err := GetVRAMInfo()
|
||||
if err != nil {
|
||||
t.Skipf("no VRAM sysfs info available: %v", err)
|
||||
}
|
||||
if info.Total >= info.Used {
|
||||
assert.Equal(t, info.Total-info.Used, info.Free)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue