Compare commits
1 commit
main
...
ax/review-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
41b34b6779 |
18 changed files with 174 additions and 114 deletions
17
backend.go
17
backend.go
|
|
@ -18,6 +18,9 @@ func (b *rocmBackend) Name() string { return "rocm" }
|
|||
|
||||
// Available reports whether ROCm GPU inference can run on this machine.
|
||||
// Checks for the ROCm kernel driver (/dev/kfd) and a findable llama-server binary.
|
||||
//
|
||||
// b := inference.FindBackend("rocm")
|
||||
// if b.Available() { /* safe to LoadModel */ }
|
||||
func (b *rocmBackend) Available() bool {
|
||||
if _, err := os.Stat("/dev/kfd"); err != nil {
|
||||
return false
|
||||
|
|
@ -32,8 +35,14 @@ func (b *rocmBackend) Available() bool {
|
|||
// Model architecture is read from GGUF metadata (replacing filename-based guessing).
|
||||
// If no context length is specified, defaults to min(model_context_length, 4096)
|
||||
// to prevent VRAM exhaustion on models with 128K+ native context.
|
||||
//
|
||||
// m, err := backend.LoadModel("/data/lem/gguf/model.gguf",
|
||||
// inference.WithContextLen(4096),
|
||||
// inference.WithGPULayers(-1),
|
||||
// )
|
||||
// defer m.Close()
|
||||
func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (inference.TextModel, error) {
|
||||
cfg := inference.ApplyLoadOpts(opts)
|
||||
config := inference.ApplyLoadOpts(opts)
|
||||
|
||||
binary, err := findLlamaServer()
|
||||
if err != nil {
|
||||
|
|
@ -45,12 +54,12 @@ func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (infe
|
|||
return nil, coreerr.E("rocm.LoadModel", "read model metadata", err)
|
||||
}
|
||||
|
||||
ctxLen := cfg.ContextLen
|
||||
ctxLen := config.ContextLen
|
||||
if ctxLen == 0 && meta.ContextLength > 0 {
|
||||
ctxLen = int(min(meta.ContextLength, 4096))
|
||||
}
|
||||
|
||||
srv, err := startServer(binary, path, cfg.GPULayers, ctxLen, cfg.ParallelSlots)
|
||||
srv, err := startServer(binary, path, config.GPULayers, ctxLen, config.ParallelSlots)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -85,7 +94,7 @@ func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (infe
|
|||
}
|
||||
|
||||
return &rocmModel{
|
||||
srv: srv,
|
||||
server: srv,
|
||||
modelType: meta.Architecture,
|
||||
modelInfo: inference.ModelInfo{
|
||||
Architecture: meta.Architecture,
|
||||
|
|
|
|||
|
|
@ -6,8 +6,13 @@ import (
|
|||
"forge.lthn.ai/core/go-rocm/internal/gguf"
|
||||
)
|
||||
|
||||
// DiscoverModels scans a directory for GGUF model files and returns
|
||||
// structured information about each. Files that cannot be parsed are skipped.
|
||||
// DiscoverModels scans a directory for GGUF model files.
|
||||
// Files that cannot be parsed are silently skipped.
|
||||
//
|
||||
// models, err := rocm.DiscoverModels("/data/lem/gguf")
|
||||
// for _, m := range models {
|
||||
// fmt.Printf("%s: %s %s ctx=%d\n", m.Name, m.Architecture, m.Quantisation, m.ContextLen)
|
||||
// }
|
||||
func DiscoverModels(dir string) ([]ModelInfo, error) {
|
||||
matches, err := filepath.Glob(filepath.Join(dir, "*.gguf"))
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ func writeDiscoverKV(t *testing.T, f *os.File, key string, val any) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestDiscoverModels(t *testing.T) {
|
||||
func TestDiscoverModels_Good(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create two valid GGUF model files.
|
||||
|
|
@ -112,7 +112,7 @@ func TestDiscoverModels(t *testing.T) {
|
|||
assert.Greater(t, llama.FileSize, int64(0))
|
||||
}
|
||||
|
||||
func TestDiscoverModels_EmptyDir(t *testing.T) {
|
||||
func TestDiscoverModels_Good_EmptyDir(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
models, err := DiscoverModels(dir)
|
||||
|
|
@ -120,7 +120,7 @@ func TestDiscoverModels_EmptyDir(t *testing.T) {
|
|||
assert.Empty(t, models)
|
||||
}
|
||||
|
||||
func TestDiscoverModels_NotFound(t *testing.T) {
|
||||
func TestDiscoverModels_Good_NonExistentDir(t *testing.T) {
|
||||
// filepath.Glob returns nil, nil for a pattern matching no files,
|
||||
// even when the directory does not exist.
|
||||
models, err := DiscoverModels("/nonexistent/dir")
|
||||
|
|
@ -128,7 +128,7 @@ func TestDiscoverModels_NotFound(t *testing.T) {
|
|||
assert.Empty(t, models)
|
||||
}
|
||||
|
||||
func TestDiscoverModels_SkipsCorruptFile(t *testing.T) {
|
||||
func TestDiscoverModels_Ugly_SkipsCorruptFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create a valid GGUF file.
|
||||
|
|
|
|||
22
go.mod
22
go.mod
|
|
@ -1,26 +1,22 @@
|
|||
module dappco.re/go/core/rocm
|
||||
module forge.lthn.ai/core/go-rocm
|
||||
|
||||
go 1.26.0
|
||||
|
||||
require (
|
||||
dappco.re/go/core/inference v0.1.5
|
||||
dappco.re/go/core/log v0.0.4
|
||||
forge.lthn.ai/core/go-inference v0.1.5
|
||||
forge.lthn.ai/core/go-log v0.0.4
|
||||
)
|
||||
|
||||
require github.com/kr/text v0.2.0 // indirect
|
||||
require (
|
||||
dappco.re/go/core v0.8.0-alpha.1 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
dappco.re/go/core v0.5.0
|
||||
dappco.re/go/core/api v0.2.0
|
||||
dappco.re/go/core/i18n v0.2.0
|
||||
dappco.re/go/core/io v0.2.0
|
||||
dappco.re/go/core/log v0.1.0
|
||||
dappco.re/go/core/process v0.3.0
|
||||
dappco.re/go/core/scm v0.4.0
|
||||
dappco.re/go/core/store v0.2.0
|
||||
dappco.re/go/core/ws v0.3.0
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
github.com/stretchr/testify v1.11.1
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
replace forge.lthn.ai/core/go-inference => ../go-inference
|
||||
|
|
|
|||
2
go.sum
2
go.sum
|
|
@ -1,3 +1,5 @@
|
|||
dappco.re/go/core v0.8.0-alpha.1 h1:gj7+Scv+L63Z7wMxbJYHhaRFkHJo2u4MMPuUSv/Dhtk=
|
||||
dappco.re/go/core v0.8.0-alpha.1/go.mod h1:f2/tBZ3+3IqDrg2F5F598llv0nmb/4gJVCFzM5geE4A=
|
||||
forge.lthn.ai/core/go-log v0.0.4 h1:KTuCEPgFmuM8KJfnyQ8vPOU1Jg654W74h8IJvfQMfv0=
|
||||
forge.lthn.ai/core/go-log v0.0.4/go.mod h1:r14MXKOD3LF/sI8XUJQhRk/SZHBE7jAFVuCfgkXoZPw=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
|
|
|
|||
|
|
@ -72,6 +72,10 @@ var fileTypeNames = map[uint32]string{
|
|||
|
||||
// FileTypeName returns a human-readable name for a GGML quantisation file type.
|
||||
// Unknown types return "type_N" where N is the numeric value.
|
||||
//
|
||||
// gguf.FileTypeName(15) // "Q4_K_M"
|
||||
// gguf.FileTypeName(7) // "Q8_0"
|
||||
// gguf.FileTypeName(1) // "F16"
|
||||
func FileTypeName(ft uint32) string {
|
||||
if name, ok := fileTypeNames[ft]; ok {
|
||||
return name
|
||||
|
|
@ -81,6 +85,9 @@ func FileTypeName(ft uint32) string {
|
|||
|
||||
// ReadMetadata reads the GGUF header from the file at path and returns the
|
||||
// extracted metadata. Only metadata KV pairs are read; tensor data is not loaded.
|
||||
//
|
||||
// meta, err := gguf.ReadMetadata("/data/model.gguf")
|
||||
// fmt.Printf("arch=%s ctx=%d quant=%s", meta.Architecture, meta.ContextLength, gguf.FileTypeName(meta.FileType))
|
||||
func ReadMetadata(path string) (Metadata, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
|
|
@ -123,15 +130,15 @@ func ReadMetadata(path string) (Metadata, error) {
|
|||
return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading kv count", err)
|
||||
}
|
||||
} else {
|
||||
var tc, kc uint32
|
||||
if err := binary.Read(r, binary.LittleEndian, &tc); err != nil {
|
||||
var tensorCount32, kvCount32 uint32
|
||||
if err := binary.Read(r, binary.LittleEndian, &tensorCount32); err != nil {
|
||||
return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading tensor count", err)
|
||||
}
|
||||
if err := binary.Read(r, binary.LittleEndian, &kc); err != nil {
|
||||
if err := binary.Read(r, binary.LittleEndian, &kvCount32); err != nil {
|
||||
return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading kv count", err)
|
||||
}
|
||||
tensorCount = uint64(tc)
|
||||
kvCount = uint64(kc)
|
||||
tensorCount = uint64(tensorCount32)
|
||||
kvCount = uint64(kvCount32)
|
||||
}
|
||||
_ = tensorCount // we only read metadata KVs
|
||||
|
||||
|
|
|
|||
|
|
@ -109,7 +109,7 @@ func writeTestGGUFV2(t *testing.T, kvs [][2]any) string {
|
|||
return path
|
||||
}
|
||||
|
||||
func TestReadMetadata_Gemma3(t *testing.T) {
|
||||
func TestReadMetadata_Good_Gemma3(t *testing.T) {
|
||||
path := writeTestGGUFOrdered(t, [][2]any{
|
||||
{"general.architecture", "gemma3"},
|
||||
{"general.name", "Test Gemma3 1B"},
|
||||
|
|
@ -131,7 +131,7 @@ func TestReadMetadata_Gemma3(t *testing.T) {
|
|||
assert.Greater(t, m.FileSize, int64(0))
|
||||
}
|
||||
|
||||
func TestReadMetadata_Llama(t *testing.T) {
|
||||
func TestReadMetadata_Good_Llama(t *testing.T) {
|
||||
path := writeTestGGUFOrdered(t, [][2]any{
|
||||
{"general.architecture", "llama"},
|
||||
{"general.name", "Test Llama 8B"},
|
||||
|
|
@ -153,7 +153,7 @@ func TestReadMetadata_Llama(t *testing.T) {
|
|||
assert.Greater(t, m.FileSize, int64(0))
|
||||
}
|
||||
|
||||
func TestReadMetadata_ArchAfterContextLength(t *testing.T) {
|
||||
func TestReadMetadata_Ugly_ArchAfterContextLength(t *testing.T) {
|
||||
// Architecture key comes AFTER the arch-specific keys.
|
||||
// The parser must handle deferred resolution of arch-prefixed keys.
|
||||
path := writeTestGGUFOrdered(t, [][2]any{
|
||||
|
|
@ -174,7 +174,7 @@ func TestReadMetadata_ArchAfterContextLength(t *testing.T) {
|
|||
assert.Equal(t, uint32(32), m.BlockCount)
|
||||
}
|
||||
|
||||
func TestReadMetadata_InvalidMagic(t *testing.T) {
|
||||
func TestReadMetadata_Bad_InvalidMagic(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "notgguf.bin")
|
||||
|
||||
|
|
@ -186,12 +186,12 @@ func TestReadMetadata_InvalidMagic(t *testing.T) {
|
|||
assert.Contains(t, err.Error(), "invalid magic")
|
||||
}
|
||||
|
||||
func TestReadMetadata_FileNotFound(t *testing.T) {
|
||||
func TestReadMetadata_Bad_FileNotFound(t *testing.T) {
|
||||
_, err := ReadMetadata("/nonexistent/path/model.gguf")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestFileTypeName(t *testing.T) {
|
||||
func TestFileTypeName_Good(t *testing.T) {
|
||||
assert.Equal(t, "Q4_K_M", FileTypeName(15))
|
||||
assert.Equal(t, "Q5_K_M", FileTypeName(17))
|
||||
assert.Equal(t, "Q8_0", FileTypeName(7))
|
||||
|
|
@ -199,7 +199,7 @@ func TestFileTypeName(t *testing.T) {
|
|||
assert.Equal(t, "type_999", FileTypeName(999))
|
||||
}
|
||||
|
||||
func TestReadMetadata_V2(t *testing.T) {
|
||||
func TestReadMetadata_Good_V2(t *testing.T) {
|
||||
// GGUF v2 uses uint32 for tensor and KV counts (instead of uint64 in v3).
|
||||
path := writeTestGGUFV2(t, [][2]any{
|
||||
{"general.architecture", "llama"},
|
||||
|
|
@ -219,7 +219,7 @@ func TestReadMetadata_V2(t *testing.T) {
|
|||
assert.Equal(t, uint32(16), m.BlockCount)
|
||||
}
|
||||
|
||||
func TestReadMetadata_UnsupportedVersion(t *testing.T) {
|
||||
func TestReadMetadata_Bad_UnsupportedVersion(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "bad_version.gguf")
|
||||
|
||||
|
|
@ -235,7 +235,7 @@ func TestReadMetadata_UnsupportedVersion(t *testing.T) {
|
|||
assert.Contains(t, err.Error(), "unsupported GGUF version")
|
||||
}
|
||||
|
||||
func TestReadMetadata_SkipsUnknownValueTypes(t *testing.T) {
|
||||
func TestReadMetadata_Ugly_SkipsUnknownValueTypes(t *testing.T) {
|
||||
// Tests skipValue for uint8, int16, float32, uint64, bool, and array types.
|
||||
// These are stored under uninteresting keys so ReadMetadata skips them.
|
||||
dir := t.TempDir()
|
||||
|
|
@ -283,7 +283,7 @@ func TestReadMetadata_SkipsUnknownValueTypes(t *testing.T) {
|
|||
b8 := make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(b8, 3) // count: 3
|
||||
arrBuf = append(arrBuf, b8...)
|
||||
arrBuf = append(arrBuf, 10, 20, 30) // 3 uint8 values
|
||||
arrBuf = append(arrBuf, 10, 20, 30) // 3 uint8 values
|
||||
writeRawKV(t, f, "custom.array_val", 9, arrBuf)
|
||||
|
||||
// 7-8. Interesting keys to verify parsing continued correctly.
|
||||
|
|
@ -299,7 +299,7 @@ func TestReadMetadata_SkipsUnknownValueTypes(t *testing.T) {
|
|||
assert.Equal(t, "Skip Test Model", m.Name)
|
||||
}
|
||||
|
||||
func TestReadMetadata_Uint64ContextLength(t *testing.T) {
|
||||
func TestReadMetadata_Ugly_Uint64ContextLength(t *testing.T) {
|
||||
// context_length stored as uint64 that fits in uint32 — readTypedValue
|
||||
// should downcast it to uint32.
|
||||
path := writeTestGGUFOrdered(t, [][2]any{
|
||||
|
|
@ -315,7 +315,7 @@ func TestReadMetadata_Uint64ContextLength(t *testing.T) {
|
|||
assert.Equal(t, uint32(32), m.BlockCount)
|
||||
}
|
||||
|
||||
func TestReadMetadata_TruncatedFile(t *testing.T) {
|
||||
func TestReadMetadata_Bad_TruncatedFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "truncated.gguf")
|
||||
|
||||
|
|
@ -330,7 +330,7 @@ func TestReadMetadata_TruncatedFile(t *testing.T) {
|
|||
assert.Contains(t, err.Error(), "reading version")
|
||||
}
|
||||
|
||||
func TestReadMetadata_SkipsStringValue(t *testing.T) {
|
||||
func TestReadMetadata_Ugly_SkipsStringValue(t *testing.T) {
|
||||
// Tests skipValue for string type (type 8) on an uninteresting key.
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "skip_string.gguf")
|
||||
|
|
|
|||
|
|
@ -60,8 +60,14 @@ type completionChunkResponse struct {
|
|||
}
|
||||
|
||||
// ChatComplete sends a streaming chat completion request to /v1/chat/completions.
|
||||
// It returns an iterator over text chunks and a function that returns any error
|
||||
// that occurred during the request or while reading the stream.
|
||||
// Returns an iterator over text chunks and an error accessor called after ranging.
|
||||
//
|
||||
// chunks, errFn := client.ChatComplete(ctx, llamacpp.ChatRequest{
|
||||
// Messages: []llamacpp.ChatMessage{{Role: "user", Content: "Hello"}},
|
||||
// MaxTokens: 128, Temperature: 0.7,
|
||||
// })
|
||||
// for text := range chunks { fmt.Print(text) }
|
||||
// if err := errFn(); err != nil { /* handle */ }
|
||||
func (c *Client) ChatComplete(ctx context.Context, req ChatRequest) (iter.Seq[string], func() error) {
|
||||
req.Stream = true
|
||||
|
||||
|
|
@ -125,8 +131,13 @@ func (c *Client) ChatComplete(ctx context.Context, req ChatRequest) (iter.Seq[st
|
|||
}
|
||||
|
||||
// Complete sends a streaming completion request to /v1/completions.
|
||||
// It returns an iterator over text chunks and a function that returns any error
|
||||
// that occurred during the request or while reading the stream.
|
||||
// Returns an iterator over text chunks and an error accessor called after ranging.
|
||||
//
|
||||
// chunks, errFn := client.Complete(ctx, llamacpp.CompletionRequest{
|
||||
// Prompt: "The capital of France is", MaxTokens: 16, Temperature: 0.0,
|
||||
// })
|
||||
// for text := range chunks { fmt.Print(text) }
|
||||
// if err := errFn(); err != nil { /* handle */ }
|
||||
func (c *Client) Complete(ctx context.Context, req CompletionRequest) (iter.Seq[string], func() error) {
|
||||
req.Stream = true
|
||||
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ func sseLines(w http.ResponseWriter, lines []string) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestChatComplete_Streaming(t *testing.T) {
|
||||
func TestChatComplete_Good_Streaming(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/v1/chat/completions", r.URL.Path)
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
|
|
@ -56,7 +56,7 @@ func TestChatComplete_Streaming(t *testing.T) {
|
|||
assert.Equal(t, []string{"Hello", " world"}, got)
|
||||
}
|
||||
|
||||
func TestChatComplete_EmptyResponse(t *testing.T) {
|
||||
func TestChatComplete_Good_EmptyResponse(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sseLines(w, []string{"[DONE]"})
|
||||
}))
|
||||
|
|
@ -77,7 +77,7 @@ func TestChatComplete_EmptyResponse(t *testing.T) {
|
|||
assert.Empty(t, got)
|
||||
}
|
||||
|
||||
func TestChatComplete_HTTPError(t *testing.T) {
|
||||
func TestChatComplete_Bad_HTTPError(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "internal server error", http.StatusInternalServerError)
|
||||
}))
|
||||
|
|
@ -100,7 +100,7 @@ func TestChatComplete_HTTPError(t *testing.T) {
|
|||
assert.Contains(t, err.Error(), "500")
|
||||
}
|
||||
|
||||
func TestChatComplete_ContextCancelled(t *testing.T) {
|
||||
func TestChatComplete_Ugly_ContextCancelled(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
|
|
@ -140,7 +140,7 @@ func TestChatComplete_ContextCancelled(t *testing.T) {
|
|||
assert.Equal(t, []string{"Hello"}, got)
|
||||
}
|
||||
|
||||
func TestComplete_Streaming(t *testing.T) {
|
||||
func TestComplete_Good_Streaming(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/v1/completions", r.URL.Path)
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
|
|
@ -170,7 +170,7 @@ func TestComplete_Streaming(t *testing.T) {
|
|||
assert.Equal(t, []string{"Once", " upon", " a time"}, got)
|
||||
}
|
||||
|
||||
func TestComplete_HTTPError(t *testing.T) {
|
||||
func TestComplete_Bad_HTTPError(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "bad request", http.StatusBadRequest)
|
||||
}))
|
||||
|
|
|
|||
|
|
@ -18,6 +18,11 @@ type Client struct {
|
|||
}
|
||||
|
||||
// NewClient creates a client for the llama-server at the given base URL.
|
||||
//
|
||||
// client := llamacpp.NewClient("http://127.0.0.1:8080")
|
||||
// if err := client.Health(ctx); err != nil {
|
||||
// // server not ready
|
||||
// }
|
||||
func NewClient(baseURL string) *Client {
|
||||
return &Client{
|
||||
baseURL: strings.TrimRight(baseURL, "/"),
|
||||
|
|
@ -30,6 +35,11 @@ type healthResponse struct {
|
|||
}
|
||||
|
||||
// Health checks whether the llama-server is ready to accept requests.
|
||||
// Returns nil when the server responds with {"status":"ok"}.
|
||||
//
|
||||
// if err := client.Health(ctx); err != nil {
|
||||
// log.Printf("server not ready: %v", err)
|
||||
// }
|
||||
func (c *Client) Health(ctx context.Context) error {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+"/health", nil)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHealth_OK(t *testing.T) {
|
||||
func TestHealth_Good(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/health", r.URL.Path)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
|
@ -23,7 +23,7 @@ func TestHealth_OK(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestHealth_NotReady(t *testing.T) {
|
||||
func TestHealth_Bad_NotReady(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"status":"loading model"}`))
|
||||
|
|
@ -35,7 +35,7 @@ func TestHealth_NotReady(t *testing.T) {
|
|||
assert.ErrorContains(t, err, "not ready")
|
||||
}
|
||||
|
||||
func TestHealth_Loading(t *testing.T) {
|
||||
func TestHealth_Bad_Loading(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
|
|
@ -48,7 +48,7 @@ func TestHealth_Loading(t *testing.T) {
|
|||
assert.ErrorContains(t, err, "503")
|
||||
}
|
||||
|
||||
func TestHealth_ServerDown(t *testing.T) {
|
||||
func TestHealth_Bad_ServerDown(t *testing.T) {
|
||||
c := NewClient("http://127.0.0.1:1") // nothing listening
|
||||
err := c.Health(context.Background())
|
||||
assert.Error(t, err)
|
||||
|
|
|
|||
72
model.go
72
model.go
|
|
@ -17,7 +17,7 @@ import (
|
|||
|
||||
// rocmModel implements inference.TextModel using a llama-server subprocess.
|
||||
type rocmModel struct {
|
||||
srv *server
|
||||
server *server
|
||||
modelType string
|
||||
modelInfo inference.ModelInfo
|
||||
|
||||
|
|
@ -32,24 +32,24 @@ func (m *rocmModel) Generate(ctx context.Context, prompt string, opts ...inferen
|
|||
m.lastErr = nil
|
||||
m.mu.Unlock()
|
||||
|
||||
if !m.srv.alive() {
|
||||
if !m.server.alive() {
|
||||
m.setServerExitErr()
|
||||
return func(yield func(inference.Token) bool) {}
|
||||
}
|
||||
|
||||
cfg := inference.ApplyGenerateOpts(opts)
|
||||
config := inference.ApplyGenerateOpts(opts)
|
||||
|
||||
req := llamacpp.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
MaxTokens: cfg.MaxTokens,
|
||||
Temperature: cfg.Temperature,
|
||||
TopK: cfg.TopK,
|
||||
TopP: cfg.TopP,
|
||||
RepeatPenalty: cfg.RepeatPenalty,
|
||||
MaxTokens: config.MaxTokens,
|
||||
Temperature: config.Temperature,
|
||||
TopK: config.TopK,
|
||||
TopP: config.TopP,
|
||||
RepeatPenalty: config.RepeatPenalty,
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
chunks, errFn := m.srv.client.Complete(ctx, req)
|
||||
chunks, errFn := m.server.client.Complete(ctx, req)
|
||||
|
||||
return func(yield func(inference.Token) bool) {
|
||||
var count int
|
||||
|
|
@ -75,12 +75,12 @@ func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts
|
|||
m.lastErr = nil
|
||||
m.mu.Unlock()
|
||||
|
||||
if !m.srv.alive() {
|
||||
if !m.server.alive() {
|
||||
m.setServerExitErr()
|
||||
return func(yield func(inference.Token) bool) {}
|
||||
}
|
||||
|
||||
cfg := inference.ApplyGenerateOpts(opts)
|
||||
config := inference.ApplyGenerateOpts(opts)
|
||||
|
||||
chatMsgs := make([]llamacpp.ChatMessage, len(messages))
|
||||
for i, msg := range messages {
|
||||
|
|
@ -92,15 +92,15 @@ func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts
|
|||
|
||||
req := llamacpp.ChatRequest{
|
||||
Messages: chatMsgs,
|
||||
MaxTokens: cfg.MaxTokens,
|
||||
Temperature: cfg.Temperature,
|
||||
TopK: cfg.TopK,
|
||||
TopP: cfg.TopP,
|
||||
RepeatPenalty: cfg.RepeatPenalty,
|
||||
MaxTokens: config.MaxTokens,
|
||||
Temperature: config.Temperature,
|
||||
TopK: config.TopK,
|
||||
TopP: config.TopP,
|
||||
RepeatPenalty: config.RepeatPenalty,
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
chunks, errFn := m.srv.client.ChatComplete(ctx, req)
|
||||
chunks, errFn := m.server.client.ChatComplete(ctx, req)
|
||||
|
||||
return func(yield func(inference.Token) bool) {
|
||||
var count int
|
||||
|
|
@ -124,7 +124,7 @@ func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts
|
|||
// Each prompt gets a single-token completion (max_tokens=1, temperature=0).
|
||||
// llama-server has no native classify endpoint, so this simulates it.
|
||||
func (m *rocmModel) Classify(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
|
||||
if !m.srv.alive() {
|
||||
if !m.server.alive() {
|
||||
m.setServerExitErr()
|
||||
return nil, m.Err()
|
||||
}
|
||||
|
|
@ -143,7 +143,7 @@ func (m *rocmModel) Classify(ctx context.Context, prompts []string, opts ...infe
|
|||
Temperature: 0,
|
||||
}
|
||||
|
||||
chunks, errFn := m.srv.client.Complete(ctx, req)
|
||||
chunks, errFn := m.server.client.Complete(ctx, req)
|
||||
var text strings.Builder
|
||||
for chunk := range chunks {
|
||||
text.WriteString(chunk)
|
||||
|
|
@ -164,12 +164,12 @@ func (m *rocmModel) Classify(ctx context.Context, prompts []string, opts ...infe
|
|||
// BatchGenerate runs batched autoregressive generation via llama-server.
|
||||
// Each prompt is decoded sequentially up to MaxTokens.
|
||||
func (m *rocmModel) BatchGenerate(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.BatchResult, error) {
|
||||
if !m.srv.alive() {
|
||||
if !m.server.alive() {
|
||||
m.setServerExitErr()
|
||||
return nil, m.Err()
|
||||
}
|
||||
|
||||
cfg := inference.ApplyGenerateOpts(opts)
|
||||
config := inference.ApplyGenerateOpts(opts)
|
||||
start := time.Now()
|
||||
results := make([]inference.BatchResult, len(prompts))
|
||||
var totalGenerated int
|
||||
|
|
@ -182,14 +182,14 @@ func (m *rocmModel) BatchGenerate(ctx context.Context, prompts []string, opts ..
|
|||
|
||||
req := llamacpp.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
MaxTokens: cfg.MaxTokens,
|
||||
Temperature: cfg.Temperature,
|
||||
TopK: cfg.TopK,
|
||||
TopP: cfg.TopP,
|
||||
RepeatPenalty: cfg.RepeatPenalty,
|
||||
MaxTokens: config.MaxTokens,
|
||||
Temperature: config.Temperature,
|
||||
TopK: config.TopK,
|
||||
TopP: config.TopP,
|
||||
RepeatPenalty: config.RepeatPenalty,
|
||||
}
|
||||
|
||||
chunks, errFn := m.srv.client.Complete(ctx, req)
|
||||
chunks, errFn := m.server.client.Complete(ctx, req)
|
||||
var tokens []inference.Token
|
||||
for text := range chunks {
|
||||
tokens = append(tokens, inference.Token{Text: text})
|
||||
|
|
@ -227,15 +227,15 @@ func (m *rocmModel) Err() error {
|
|||
|
||||
// Close releases the llama-server subprocess and all associated resources.
|
||||
func (m *rocmModel) Close() error {
|
||||
return m.srv.stop()
|
||||
return m.server.stop()
|
||||
}
|
||||
|
||||
// setServerExitErr stores an appropriate error when the server is dead.
|
||||
func (m *rocmModel) setServerExitErr() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.srv.exitErr != nil {
|
||||
m.lastErr = coreerr.E("rocm.setServerExitErr", "server has exited", m.srv.exitErr)
|
||||
if m.server.exitErr != nil {
|
||||
m.lastErr = coreerr.E("rocm.setServerExitErr", "server has exited", m.server.exitErr)
|
||||
} else {
|
||||
m.lastErr = coreerr.E("rocm.setServerExitErr", "server has exited unexpectedly", nil)
|
||||
}
|
||||
|
|
@ -248,7 +248,7 @@ func (m *rocmModel) recordMetrics(promptTokens, generatedTokens int, start, deco
|
|||
decode := now.Sub(decodeStart)
|
||||
prefill := total - decode
|
||||
|
||||
met := inference.GenerateMetrics{
|
||||
metrics := inference.GenerateMetrics{
|
||||
PromptTokens: promptTokens,
|
||||
GeneratedTokens: generatedTokens,
|
||||
PrefillDuration: prefill,
|
||||
|
|
@ -256,19 +256,19 @@ func (m *rocmModel) recordMetrics(promptTokens, generatedTokens int, start, deco
|
|||
TotalDuration: total,
|
||||
}
|
||||
if prefill > 0 && promptTokens > 0 {
|
||||
met.PrefillTokensPerSec = float64(promptTokens) / prefill.Seconds()
|
||||
metrics.PrefillTokensPerSec = float64(promptTokens) / prefill.Seconds()
|
||||
}
|
||||
if decode > 0 && generatedTokens > 0 {
|
||||
met.DecodeTokensPerSec = float64(generatedTokens) / decode.Seconds()
|
||||
metrics.DecodeTokensPerSec = float64(generatedTokens) / decode.Seconds()
|
||||
}
|
||||
|
||||
// Try to get VRAM stats — best effort.
|
||||
if vram, err := GetVRAMInfo(); err == nil {
|
||||
met.PeakMemoryBytes = vram.Used
|
||||
met.ActiveMemoryBytes = vram.Used
|
||||
metrics.PeakMemoryBytes = vram.Used
|
||||
metrics.ActiveMemoryBytes = vram.Used
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
m.metrics = met
|
||||
m.metrics = metrics
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,5 +8,10 @@ func init() {
|
|||
inference.Register(&rocmBackend{})
|
||||
}
|
||||
|
||||
// ROCmAvailable reports whether ROCm GPU inference is available.
|
||||
// ROCmAvailable reports whether ROCm GPU inference is available on this machine.
|
||||
// Returns true only on linux/amd64 with /dev/kfd present and llama-server findable.
|
||||
//
|
||||
// if rocm.ROCmAvailable() {
|
||||
// m, err := inference.LoadModel("/data/model.gguf")
|
||||
// }
|
||||
func ROCmAvailable() bool { return true }
|
||||
|
|
|
|||
8
rocm.go
8
rocm.go
|
|
@ -25,6 +25,9 @@
|
|||
package rocm
|
||||
|
||||
// VRAMInfo reports GPU video memory usage in bytes.
|
||||
//
|
||||
// info, err := rocm.GetVRAMInfo()
|
||||
// fmt.Printf("VRAM: %d MiB used / %d MiB total", info.Used/(1024*1024), info.Total/(1024*1024))
|
||||
type VRAMInfo struct {
|
||||
Total uint64
|
||||
Used uint64
|
||||
|
|
@ -32,6 +35,11 @@ type VRAMInfo struct {
|
|||
}
|
||||
|
||||
// ModelInfo describes a GGUF model file discovered on disk.
|
||||
//
|
||||
// models, _ := rocm.DiscoverModels("/data/lem/gguf")
|
||||
// for _, m := range models {
|
||||
// fmt.Printf("%s (%s %s, ctx=%d)\n", m.Name, m.Architecture, m.Quantisation, m.ContextLen)
|
||||
// }
|
||||
type ModelInfo struct {
|
||||
Path string // full path to .gguf file
|
||||
Architecture string // GGUF architecture (e.g. "gemma3", "llama", "qwen2")
|
||||
|
|
|
|||
|
|
@ -4,8 +4,12 @@ package rocm
|
|||
|
||||
import coreerr "forge.lthn.ai/core/go-log"
|
||||
|
||||
// ROCmAvailable reports whether ROCm GPU inference is available.
|
||||
// ROCmAvailable reports whether ROCm GPU inference is available on this machine.
|
||||
// Returns false on non-Linux or non-amd64 platforms.
|
||||
//
|
||||
// if rocm.ROCmAvailable() {
|
||||
// m, err := inference.LoadModel("/data/model.gguf")
|
||||
// }
|
||||
func ROCmAvailable() bool { return false }
|
||||
|
||||
// GetVRAMInfo is not available on non-Linux/non-amd64 platforms.
|
||||
|
|
|
|||
|
|
@ -14,34 +14,34 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFindLlamaServer_InPATH(t *testing.T) {
|
||||
func TestFindLlamaServer_Good_InPATH(t *testing.T) {
|
||||
// llama-server is at /usr/local/bin/llama-server on this machine.
|
||||
path, err := findLlamaServer()
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, path, "llama-server")
|
||||
}
|
||||
|
||||
func TestFindLlamaServer_EnvOverride(t *testing.T) {
|
||||
func TestFindLlamaServer_Good_EnvOverride(t *testing.T) {
|
||||
t.Setenv("ROCM_LLAMA_SERVER_PATH", "/usr/local/bin/llama-server")
|
||||
path, err := findLlamaServer()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "/usr/local/bin/llama-server", path)
|
||||
}
|
||||
|
||||
func TestFindLlamaServer_EnvNotFound(t *testing.T) {
|
||||
func TestFindLlamaServer_Bad_EnvPathMissing(t *testing.T) {
|
||||
t.Setenv("ROCM_LLAMA_SERVER_PATH", "/nonexistent/llama-server")
|
||||
_, err := findLlamaServer()
|
||||
assert.ErrorContains(t, err, "not found")
|
||||
}
|
||||
|
||||
func TestFreePort(t *testing.T) {
|
||||
func TestFreePort_Good(t *testing.T) {
|
||||
port, err := freePort()
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, port, 0)
|
||||
assert.Less(t, port, 65536)
|
||||
}
|
||||
|
||||
func TestFreePort_UniquePerCall(t *testing.T) {
|
||||
func TestFreePort_Good_UniquePerCall(t *testing.T) {
|
||||
p1, err := freePort()
|
||||
require.NoError(t, err)
|
||||
p2, err := freePort()
|
||||
|
|
@ -50,7 +50,7 @@ func TestFreePort_UniquePerCall(t *testing.T) {
|
|||
_ = p2
|
||||
}
|
||||
|
||||
func TestServerEnv_HIPVisibleDevices(t *testing.T) {
|
||||
func TestServerEnv_Good_SetsHIPVisibleDevices(t *testing.T) {
|
||||
env := serverEnv()
|
||||
var hipVals []string
|
||||
for _, e := range env {
|
||||
|
|
@ -61,7 +61,7 @@ func TestServerEnv_HIPVisibleDevices(t *testing.T) {
|
|||
assert.Equal(t, []string{"HIP_VISIBLE_DEVICES=0"}, hipVals)
|
||||
}
|
||||
|
||||
func TestServerEnv_FiltersExistingHIP(t *testing.T) {
|
||||
func TestServerEnv_Good_FiltersExistingHIPVisibleDevices(t *testing.T) {
|
||||
t.Setenv("HIP_VISIBLE_DEVICES", "1")
|
||||
env := serverEnv()
|
||||
var hipVals []string
|
||||
|
|
@ -73,7 +73,7 @@ func TestServerEnv_FiltersExistingHIP(t *testing.T) {
|
|||
assert.Equal(t, []string{"HIP_VISIBLE_DEVICES=0"}, hipVals)
|
||||
}
|
||||
|
||||
func TestAvailable(t *testing.T) {
|
||||
func TestAvailable_Good(t *testing.T) {
|
||||
b := &rocmBackend{}
|
||||
if _, err := os.Stat("/dev/kfd"); err != nil {
|
||||
t.Skip("no ROCm hardware")
|
||||
|
|
@ -81,27 +81,26 @@ func TestAvailable(t *testing.T) {
|
|||
assert.True(t, b.Available())
|
||||
}
|
||||
|
||||
|
||||
func TestServerAlive_Running(t *testing.T) {
|
||||
func TestServerAlive_Good_Running(t *testing.T) {
|
||||
s := &server{exited: make(chan struct{})}
|
||||
assert.True(t, s.alive())
|
||||
}
|
||||
|
||||
func TestServerAlive_Exited(t *testing.T) {
|
||||
func TestServerAlive_Good_Exited(t *testing.T) {
|
||||
exited := make(chan struct{})
|
||||
close(exited)
|
||||
s := &server{exited: exited, exitErr: coreerr.E("test", "process killed", nil)}
|
||||
assert.False(t, s.alive())
|
||||
}
|
||||
|
||||
func TestGenerate_ServerDead(t *testing.T) {
|
||||
func TestGenerate_Bad_ServerDead(t *testing.T) {
|
||||
exited := make(chan struct{})
|
||||
close(exited)
|
||||
s := &server{
|
||||
exited: exited,
|
||||
exitErr: coreerr.E("test", "process killed", nil),
|
||||
}
|
||||
m := &rocmModel{srv: s}
|
||||
m := &rocmModel{server: s}
|
||||
|
||||
var count int
|
||||
for range m.Generate(context.Background(), "hello") {
|
||||
|
|
@ -111,7 +110,7 @@ func TestGenerate_ServerDead(t *testing.T) {
|
|||
assert.ErrorContains(t, m.Err(), "server has exited")
|
||||
}
|
||||
|
||||
func TestStartServer_RetriesOnProcessExit(t *testing.T) {
|
||||
func TestStartServer_Ugly_RetriesOnProcessExit(t *testing.T) {
|
||||
// /bin/false starts successfully but exits immediately with code 1.
|
||||
// startServer should retry up to 3 times, then fail.
|
||||
_, err := startServer("/bin/false", "/nonexistent/model.gguf", 999, 0, 0)
|
||||
|
|
@ -119,14 +118,14 @@ func TestStartServer_RetriesOnProcessExit(t *testing.T) {
|
|||
assert.Contains(t, err.Error(), "failed after 3 attempts")
|
||||
}
|
||||
|
||||
func TestChat_ServerDead(t *testing.T) {
|
||||
func TestChat_Bad_ServerDead(t *testing.T) {
|
||||
exited := make(chan struct{})
|
||||
close(exited)
|
||||
s := &server{
|
||||
exited: exited,
|
||||
exitErr: coreerr.E("test", "process killed", nil),
|
||||
}
|
||||
m := &rocmModel{srv: s}
|
||||
m := &rocmModel{server: s}
|
||||
|
||||
msgs := []inference.Message{{Role: "user", Content: "hello"}}
|
||||
var count int
|
||||
|
|
|
|||
4
vram.go
4
vram.go
|
|
@ -17,6 +17,10 @@ import (
|
|||
//
|
||||
// Note: total and used are read non-atomically from sysfs; transient
|
||||
// inconsistencies are possible under heavy allocation churn.
|
||||
//
|
||||
// info, err := rocm.GetVRAMInfo()
|
||||
// fmt.Printf("VRAM: %d MiB used / %d MiB total (free: %d MiB)",
|
||||
// info.Used/(1024*1024), info.Total/(1024*1024), info.Free/(1024*1024))
|
||||
func GetVRAMInfo() (VRAMInfo, error) {
|
||||
cards, err := filepath.Glob("/sys/class/drm/card[0-9]*/device/mem_info_vram_total")
|
||||
if err != nil {
|
||||
|
|
|
|||
10
vram_test.go
10
vram_test.go
|
|
@ -11,7 +11,7 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestReadSysfsUint64(t *testing.T) {
|
||||
func TestReadSysfsUint64_Good(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test_value")
|
||||
require.NoError(t, os.WriteFile(path, []byte("17163091968\n"), 0644))
|
||||
|
|
@ -21,12 +21,12 @@ func TestReadSysfsUint64(t *testing.T) {
|
|||
assert.Equal(t, uint64(17163091968), val)
|
||||
}
|
||||
|
||||
func TestReadSysfsUint64_NotFound(t *testing.T) {
|
||||
func TestReadSysfsUint64_Bad_NotFound(t *testing.T) {
|
||||
_, err := readSysfsUint64("/nonexistent/path")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestReadSysfsUint64_InvalidContent(t *testing.T) {
|
||||
func TestReadSysfsUint64_Bad_InvalidContent(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "bad_value")
|
||||
require.NoError(t, os.WriteFile(path, []byte("not-a-number\n"), 0644))
|
||||
|
|
@ -35,7 +35,7 @@ func TestReadSysfsUint64_InvalidContent(t *testing.T) {
|
|||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestReadSysfsUint64_EmptyFile(t *testing.T) {
|
||||
func TestReadSysfsUint64_Bad_EmptyFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "empty_value")
|
||||
require.NoError(t, os.WriteFile(path, []byte(""), 0644))
|
||||
|
|
@ -44,7 +44,7 @@ func TestReadSysfsUint64_EmptyFile(t *testing.T) {
|
|||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestGetVRAMInfo(t *testing.T) {
|
||||
func TestGetVRAMInfo_Good(t *testing.T) {
|
||||
info, err := GetVRAMInfo()
|
||||
if err != nil {
|
||||
t.Skipf("no VRAM sysfs info available: %v", err)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue