Co-Authored-By: Virgil <virgil@lethean.io> Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
30 KiB
Phase 1: Core Implementation Plan
For Claude: REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
Goal: Implement the full inference.Backend and inference.TextModel interfaces for AMD ROCm GPUs using llama-server as the subprocess engine.
Architecture: Bottom-up layered design. internal/llamacpp/ is a pure HTTP+SSE client with no go-inference dependency. server.go manages the llama-server process lifecycle. model.go implements inference.TextModel by wrapping the server and client. backend.go implements inference.Backend by wiring everything together.
Tech Stack: Go 1.25, llama-server (llama.cpp with HIP/ROCm), OpenAI-compatible API, SSE streaming, testify for assertions, httptest for mocking.
Critical environment note: HIP_VISIBLE_DEVICES=0 MUST be set when spawning llama-server. The Ryzen 9 9950X iGPU (Device 1) reports 100GB free memory (system RAM) and crashes llama-server when it tries to split tensors across devices.
Task 1: Health Check Client
Files:
- Create: internal/llamacpp/health.go
- Create: internal/llamacpp/health_test.go
Step 1: Create directory and write the failing test
Create internal/llamacpp/ directory, then write the test:
// internal/llamacpp/health_test.go
package llamacpp
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestHealth_OK(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/health", r.URL.Path)
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"status":"ok"}`))
}))
defer ts.Close()
c := NewClient(ts.URL)
err := c.Health(context.Background())
require.NoError(t, err)
}
func TestHealth_NotReady(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"status":"loading model"}`))
}))
defer ts.Close()
c := NewClient(ts.URL)
err := c.Health(context.Background())
assert.ErrorContains(t, err, "not ready")
}
func TestHealth_ServerDown(t *testing.T) {
c := NewClient("http://127.0.0.1:1") // nothing listening
err := c.Health(context.Background())
assert.Error(t, err)
}
Step 2: Run test to verify it fails
Run: go test ./internal/llamacpp/ -run TestHealth -v
Expected: FAIL — NewClient and Health not defined.
Step 3: Write minimal implementation
// internal/llamacpp/health.go
package llamacpp
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
)
// Client communicates with a llama-server instance.
type Client struct {
baseURL string
httpClient *http.Client
}
// NewClient creates a client for the llama-server at the given base URL.
func NewClient(baseURL string) *Client {
return &Client{
baseURL: strings.TrimRight(baseURL, "/"),
httpClient: &http.Client{},
}
}
type healthResponse struct {
Status string `json:"status"`
}
// Health checks whether the llama-server is ready to accept requests.
func (c *Client) Health(ctx context.Context) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+"/health", nil)
if err != nil {
return err
}
resp, err := c.httpClient.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("llamacpp: health returned %d", resp.StatusCode)
}
var h healthResponse
if err := json.NewDecoder(resp.Body).Decode(&h); err != nil {
return fmt.Errorf("llamacpp: health decode: %w", err)
}
if h.Status != "ok" {
return fmt.Errorf("llamacpp: server not ready (status: %s)", h.Status)
}
return nil
}
Step 4: Run test to verify it passes
Run: go test ./internal/llamacpp/ -run TestHealth -v
Expected: PASS (all 3 tests)
Step 5: Commit
git add internal/llamacpp/health.go internal/llamacpp/health_test.go
git commit -m "feat: llamacpp health check client"
Task 2: SSE Streaming + Chat Completions
Files:
- Create: internal/llamacpp/client.go
- Create: internal/llamacpp/client_test.go
Step 1: Write the failing tests
// internal/llamacpp/client_test.go
package llamacpp
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestChatComplete_Streaming verifies that delta content from each SSE
// chunk is yielded in order, and that the empty finish chunk and the
// [DONE] sentinel produce no output.
func TestChatComplete_Streaming(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, "/v1/chat/completions", r.URL.Path)
		assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
		flusher, ok := w.(http.Flusher)
		require.True(t, ok)
		w.Header().Set("Content-Type", "text/event-stream")
		w.WriteHeader(http.StatusOK)
		chunks := []string{
			`{"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}`,
			`{"choices":[{"delta":{"content":" world"},"finish_reason":null}]}`,
			`{"choices":[{"delta":{},"finish_reason":"stop"}]}`,
		}
		for _, c := range chunks {
			fmt.Fprintf(w, "data: %s\n\n", c)
			flusher.Flush() // flush per event so the client sees a live stream
		}
		fmt.Fprint(w, "data: [DONE]\n\n")
		flusher.Flush()
	}))
	defer ts.Close()
	c := NewClient(ts.URL)
	chunks, errFn := c.ChatComplete(context.Background(), ChatRequest{
		Messages:  []ChatMessage{{Role: "user", Content: "Hi"}},
		MaxTokens: 32,
	})
	var result []string
	for text := range chunks {
		result = append(result, text)
	}
	// errFn is only meaningful once the iterator has been drained.
	require.NoError(t, errFn())
	assert.Equal(t, []string{"Hello", " world"}, result)
}

// TestChatComplete_EmptyResponse verifies that a stream ending
// immediately with [DONE] yields no chunks and no error.
func TestChatComplete_EmptyResponse(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		fmt.Fprint(w, "data: [DONE]\n\n")
	}))
	defer ts.Close()
	c := NewClient(ts.URL)
	chunks, errFn := c.ChatComplete(context.Background(), ChatRequest{
		Messages:  []ChatMessage{{Role: "user", Content: "Hi"}},
		MaxTokens: 1,
	})
	var result []string
	for text := range chunks {
		result = append(result, text)
	}
	require.NoError(t, errFn())
	assert.Empty(t, result)
}

// TestChatComplete_HTTPError verifies that a non-200 response produces
// an empty iterator and surfaces the status code through errFn.
func TestChatComplete_HTTPError(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
	}))
	defer ts.Close()
	c := NewClient(ts.URL)
	chunks, errFn := c.ChatComplete(context.Background(), ChatRequest{
		Messages: []ChatMessage{{Role: "user", Content: "Hi"}},
	})
	for range chunks {
		t.Fatal("expected no chunks")
	}
	assert.ErrorContains(t, errFn(), "500")
}

// TestChatComplete_ContextCancelled verifies that cancelling the context
// mid-stream ends iteration with the chunks already received instead of
// hanging on the intentionally-blocked server.
func TestChatComplete_ContextCancelled(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// The real server's ResponseWriter implements Flusher for HTTP/1.1.
		flusher, _ := w.(http.Flusher)
		w.Header().Set("Content-Type", "text/event-stream")
		w.WriteHeader(http.StatusOK)
		// Send one chunk then block (context will cancel)
		fmt.Fprint(w, "data: {\"choices\":[{\"delta\":{\"content\":\"Hi\"},\"finish_reason\":null}]}\n\n")
		flusher.Flush()
		// Block until client disconnects
		<-r.Context().Done()
	}))
	defer ts.Close()
	ctx, cancel := context.WithCancel(context.Background())
	c := NewClient(ts.URL)
	chunks, _ := c.ChatComplete(ctx, ChatRequest{
		Messages: []ChatMessage{{Role: "user", Content: "Hi"}},
	})
	var got int
	for range chunks {
		got++
		cancel() // Cancel after first chunk
	}
	assert.Equal(t, 1, got)
}
Step 2: Run tests to verify they fail
Run: go test ./internal/llamacpp/ -run TestChat -v
Expected: FAIL — ChatComplete, ChatRequest, ChatMessage not defined.
Step 3: Write the implementation
// internal/llamacpp/client.go
package llamacpp
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"iter"
"net/http"
"strings"
)
// ChatMessage is a single message in a conversation.
type ChatMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// ChatRequest is the request body for /v1/chat/completions.
//
// NOTE(review): Temperature has no omitempty, so an unset field is sent
// as 0 rather than deferring to the server default — confirm whether
// "explicit 0 = greedy" is the intended semantics here.
type ChatRequest struct {
	Messages      []ChatMessage `json:"messages"`
	MaxTokens     int           `json:"max_tokens,omitempty"`
	Temperature   float32       `json:"temperature"`
	TopK          int           `json:"top_k,omitempty"`
	TopP          float32       `json:"top_p,omitempty"`
	RepeatPenalty float32       `json:"repeat_penalty,omitempty"`
	Stream        bool          `json:"stream"` // forced to true by ChatComplete
}

// CompletionRequest is the request body for /v1/completions.
//
// NOTE(review): same Temperature/omitempty caveat as on this request's
// chat counterpart — an unset field serializes as 0.
type CompletionRequest struct {
	Prompt        string  `json:"prompt"`
	MaxTokens     int     `json:"max_tokens,omitempty"`
	Temperature   float32 `json:"temperature"`
	TopK          int     `json:"top_k,omitempty"`
	TopP          float32 `json:"top_p,omitempty"`
	RepeatPenalty float32 `json:"repeat_penalty,omitempty"`
	Stream        bool    `json:"stream"` // forced to true by Complete
}

// chatChunkResponse mirrors one streamed chunk of an OpenAI-style chat
// completion; only the fields this client reads are declared.
type chatChunkResponse struct {
	Choices []struct {
		Delta struct {
			Content string `json:"content"`
		} `json:"delta"`
		FinishReason *string `json:"finish_reason"`
	} `json:"choices"`
}

// completionChunkResponse mirrors one streamed chunk of /v1/completions.
type completionChunkResponse struct {
	Choices []struct {
		Text         string  `json:"text"`
		FinishReason *string `json:"finish_reason"`
	} `json:"choices"`
}
// ChatComplete sends a streaming chat completion request.
// Returns an iterator of text chunks and a function to retrieve any
// error that occurred during streaming. Call errFn after iteration.
//
// NOTE(review): the response body is closed by the iterator; if the
// caller never ranges over the returned sequence at all, the body (and
// its connection) is never closed.
func (c *Client) ChatComplete(ctx context.Context, req ChatRequest) (iter.Seq[string], func() error) {
	req.Stream = true // this client only supports the streaming protocol
	body, err := json.Marshal(req)
	if err != nil {
		return noChunks, func() error { return fmt.Errorf("llamacpp: marshal: %w", err) }
	}
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/chat/completions", bytes.NewReader(body))
	if err != nil {
		return noChunks, func() error { return err }
	}
	httpReq.Header.Set("Content-Type", "application/json")
	resp, err := c.httpClient.Do(httpReq)
	if err != nil {
		return noChunks, func() error { return err }
	}
	if resp.StatusCode != http.StatusOK {
		// Capture a short excerpt of the error body so the failure is
		// diagnosable instead of just a bare status code.
		snippet, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
		resp.Body.Close()
		return noChunks, func() error {
			if s := strings.TrimSpace(string(snippet)); s != "" {
				return fmt.Errorf("llamacpp: chat request returned %d: %s", resp.StatusCode, s)
			}
			return fmt.Errorf("llamacpp: chat request returned %d", resp.StatusCode)
		}
	}
	var streamErr error
	seq := func(yield func(string) bool) {
		defer resp.Body.Close()
		for payload := range parseSSE(resp.Body, &streamErr) {
			var chunk chatChunkResponse
			if err := json.Unmarshal([]byte(payload), &chunk); err != nil {
				streamErr = fmt.Errorf("llamacpp: parse chunk: %w", err)
				return
			}
			// Empty deltas (e.g. the final finish_reason chunk) are skipped.
			if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" {
				if !yield(chunk.Choices[0].Delta.Content) {
					return
				}
			}
		}
	}
	return seq, func() error { return streamErr }
}
// Complete sends a streaming text completion request.
// Returns an iterator of text chunks and a function to retrieve any
// error that occurred during streaming; call it after iteration.
//
// NOTE(review): the response body is closed by the iterator; if the
// caller never ranges over the returned sequence at all, the body (and
// its connection) is never closed.
func (c *Client) Complete(ctx context.Context, req CompletionRequest) (iter.Seq[string], func() error) {
	req.Stream = true // this client only supports the streaming protocol
	body, err := json.Marshal(req)
	if err != nil {
		return noChunks, func() error { return fmt.Errorf("llamacpp: marshal: %w", err) }
	}
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/completions", bytes.NewReader(body))
	if err != nil {
		return noChunks, func() error { return err }
	}
	httpReq.Header.Set("Content-Type", "application/json")
	resp, err := c.httpClient.Do(httpReq)
	if err != nil {
		return noChunks, func() error { return err }
	}
	if resp.StatusCode != http.StatusOK {
		// Capture a short excerpt of the error body so the failure is
		// diagnosable instead of just a bare status code.
		snippet, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
		resp.Body.Close()
		return noChunks, func() error {
			if s := strings.TrimSpace(string(snippet)); s != "" {
				return fmt.Errorf("llamacpp: completion request returned %d: %s", resp.StatusCode, s)
			}
			return fmt.Errorf("llamacpp: completion request returned %d", resp.StatusCode)
		}
	}
	var streamErr error
	seq := func(yield func(string) bool) {
		defer resp.Body.Close()
		for payload := range parseSSE(resp.Body, &streamErr) {
			var chunk completionChunkResponse
			if err := json.Unmarshal([]byte(payload), &chunk); err != nil {
				streamErr = fmt.Errorf("llamacpp: parse chunk: %w", err)
				return
			}
			// Empty text (e.g. the final finish_reason chunk) is skipped.
			if len(chunk.Choices) > 0 && chunk.Choices[0].Text != "" {
				if !yield(chunk.Choices[0].Text) {
					return
				}
			}
		}
	}
	return seq, func() error { return streamErr }
}
// parseSSE reads SSE data lines from r, yielding each data payload.
// Stops on "data: [DONE]" or reader exhaustion. Sets *errOut on read error.
func parseSSE(r io.Reader, errOut *error) iter.Seq[string] {
return func(yield func(string) bool) {
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") {
continue
}
payload := strings.TrimPrefix(line, "data: ")
if payload == "[DONE]" {
return
}
if !yield(payload) {
return
}
}
if err := scanner.Err(); err != nil {
*errOut = fmt.Errorf("llamacpp: read stream: %w", err)
}
}
}
// noChunks is an empty iterator for error cases.
func noChunks(func(string) bool) {}
Step 4: Run tests to verify they pass
Run: go test ./internal/llamacpp/ -run TestChat -v
Expected: PASS (all 4 tests)
Step 5: Commit
git add internal/llamacpp/client.go internal/llamacpp/client_test.go
git commit -m "feat: llamacpp SSE streaming client for chat completions"
Task 3: Text Completion Streaming Tests
Files:
- Modify:
internal/llamacpp/client_test.go(add completion tests)
The Complete method was already implemented in Task 2. This task adds its tests.
Step 1: Write the failing tests
Append to internal/llamacpp/client_test.go:
// TestComplete_Streaming verifies that /v1/completions text chunks are
// yielded in order and that the empty final chunk and [DONE] sentinel
// produce no output.
func TestComplete_Streaming(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, "/v1/completions", r.URL.Path)
		flusher, ok := w.(http.Flusher)
		require.True(t, ok)
		w.Header().Set("Content-Type", "text/event-stream")
		w.WriteHeader(http.StatusOK)
		chunks := []string{
			`{"choices":[{"text":"Once","finish_reason":null}]}`,
			`{"choices":[{"text":" upon","finish_reason":null}]}`,
			`{"choices":[{"text":" a time","finish_reason":null}]}`,
			`{"choices":[{"text":"","finish_reason":"stop"}]}`,
		}
		for _, c := range chunks {
			fmt.Fprintf(w, "data: %s\n\n", c)
			flusher.Flush() // flush per event so the client sees a live stream
		}
		fmt.Fprint(w, "data: [DONE]\n\n")
		flusher.Flush()
	}))
	defer ts.Close()
	c := NewClient(ts.URL)
	chunks, errFn := c.Complete(context.Background(), CompletionRequest{
		Prompt:    "Once",
		MaxTokens: 32,
	})
	var result []string
	for text := range chunks {
		result = append(result, text)
	}
	// errFn is only meaningful once the iterator has been drained.
	require.NoError(t, errFn())
	assert.Equal(t, []string{"Once", " upon", " a time"}, result)
}

// TestComplete_HTTPError verifies that a 400 response yields no chunks
// and surfaces the status code through errFn.
func TestComplete_HTTPError(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusBadRequest)
	}))
	defer ts.Close()
	c := NewClient(ts.URL)
	chunks, errFn := c.Complete(context.Background(), CompletionRequest{
		Prompt: "Hello",
	})
	for range chunks {
		t.Fatal("expected no chunks")
	}
	assert.ErrorContains(t, errFn(), "400")
}
Step 2: Run tests to verify they pass
Run: go test ./internal/llamacpp/ -run TestComplete -v
Expected: PASS (both tests — implementation already exists from Task 2)
Step 3: Commit
git add internal/llamacpp/client_test.go
git commit -m "test: completion streaming tests for llamacpp client"
Task 4: Server Helpers
Files:
- Create: server.go
- Create: server_test.go
These are the utility functions that don't require a GPU: finding the llama-server binary and getting a free port.
Step 1: Write the failing tests
// server_test.go
//go:build linux && amd64
package rocm
import (
"os"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestFindLlamaServer_InPATH exercises the PATH-lookup branch.
// NOTE(review): this fails (rather than skips) on machines without
// llama-server installed — see the expectation note under Task 4.
func TestFindLlamaServer_InPATH(t *testing.T) {
	// llama-server is installed at /usr/local/bin/llama-server on this machine
	path, err := findLlamaServer()
	require.NoError(t, err)
	assert.Contains(t, path, "llama-server")
}

// TestFindLlamaServer_EnvOverride verifies that a valid
// ROCM_LLAMA_SERVER_PATH takes precedence over PATH lookup.
func TestFindLlamaServer_EnvOverride(t *testing.T) {
	t.Setenv("ROCM_LLAMA_SERVER_PATH", "/usr/local/bin/llama-server")
	path, err := findLlamaServer()
	require.NoError(t, err)
	assert.Equal(t, "/usr/local/bin/llama-server", path)
}

// TestFindLlamaServer_EnvNotFound verifies that a broken override is an
// error rather than silently falling back to PATH.
func TestFindLlamaServer_EnvNotFound(t *testing.T) {
	t.Setenv("ROCM_LLAMA_SERVER_PATH", "/nonexistent/llama-server")
	_, err := findLlamaServer()
	assert.ErrorContains(t, err, "not found")
}

// TestFreePort verifies the kernel hands back a valid TCP port number.
func TestFreePort(t *testing.T) {
	port, err := freePort()
	require.NoError(t, err)
	assert.Greater(t, port, 0)
	assert.Less(t, port, 65536)
}

// TestFreePort_UniquePerCall documents that successive calls usually
// yield distinct ports. NOTE(review): it asserts nothing beyond
// error-free calls — p1 and p2 are deliberately discarded because
// kernel port reuse makes inequality non-deterministic.
func TestFreePort_UniquePerCall(t *testing.T) {
	p1, err := freePort()
	require.NoError(t, err)
	p2, err := freePort()
	require.NoError(t, err)
	// Not guaranteed but extremely likely with random kernel allocation
	_ = p1
	_ = p2
}

// TestServerEnv_HIPVisibleDevices verifies the subprocess environment
// pins HIP_VISIBLE_DEVICES=0 (masking the iGPU; see the plan header).
func TestServerEnv_HIPVisibleDevices(t *testing.T) {
	env := serverEnv()
	var hipVals []string
	for _, e := range env {
		if strings.HasPrefix(e, "HIP_VISIBLE_DEVICES=") {
			hipVals = append(hipVals, e)
		}
	}
	assert.Equal(t, []string{"HIP_VISIBLE_DEVICES=0"}, hipVals)
}

// TestServerEnv_FiltersExistingHIP verifies that an inherited
// HIP_VISIBLE_DEVICES value is removed, not duplicated alongside ours.
func TestServerEnv_FiltersExistingHIP(t *testing.T) {
	t.Setenv("HIP_VISIBLE_DEVICES", "1")
	env := serverEnv()
	var hipVals []string
	for _, e := range env {
		if strings.HasPrefix(e, "HIP_VISIBLE_DEVICES=") {
			hipVals = append(hipVals, e)
		}
	}
	// Must filter the old value and only have our override
	assert.Equal(t, []string{"HIP_VISIBLE_DEVICES=0"}, hipVals)
}
Step 2: Run tests to verify they fail
Run: go test -run "TestFind|TestFree|TestServer" -v
Expected: FAIL — findLlamaServer, freePort, serverEnv not defined.
Step 3: Write minimal implementation
// server.go
//go:build linux && amd64
package rocm
import (
"context"
"fmt"
"net"
"os"
"os/exec"
"strconv"
"strings"
"syscall"
"time"
"forge.lthn.ai/core/go-rocm/internal/llamacpp"
)
// server manages a llama-server subprocess.
type server struct {
	cmd     *exec.Cmd        // the spawned llama-server process
	port    int              // localhost TCP port the server listens on
	client  *llamacpp.Client // HTTP client bound to this server instance
	exited  chan struct{}    // closed once cmd.Wait returns
	exitErr error            // written before exited closes; read only after <-exited
}
// findLlamaServer locates the llama-server binary.
// ROCM_LLAMA_SERVER_PATH, when set, takes precedence and must point at
// an existing file (no PATH fallback); otherwise PATH is searched.
func findLlamaServer() (string, error) {
	override := os.Getenv("ROCM_LLAMA_SERVER_PATH")
	if override == "" {
		path, err := exec.LookPath("llama-server")
		if err != nil {
			return "", fmt.Errorf("rocm: llama-server not in PATH: %w", err)
		}
		return path, nil
	}
	if _, err := os.Stat(override); err != nil {
		return "", fmt.Errorf("rocm: ROCM_LLAMA_SERVER_PATH=%q not found", override)
	}
	return override, nil
}
// freePort asks the kernel for an ephemeral TCP port on localhost.
// The listener is closed before the caller uses the port, so the port
// is free but not reserved; a rare race with another process remains.
func freePort() (int, error) {
	l, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return 0, err
	}
	defer l.Close()
	return l.Addr().(*net.TCPAddr).Port, nil
}
// serverEnv returns the environment for the llama-server subprocess.
// Filters any existing HIP_VISIBLE_DEVICES and sets it to 0 to mask the iGPU.
func serverEnv() []string {
env := os.Environ()
filtered := make([]string, 0, len(env)+1)
for _, e := range env {
if !strings.HasPrefix(e, "HIP_VISIBLE_DEVICES=") {
filtered = append(filtered, e)
}
}
return append(filtered, "HIP_VISIBLE_DEVICES=0")
}
// cappedLog is an io.Writer retaining only the most recent maxLen bytes,
// so a noisy subprocess cannot grow memory without bound. It is written
// only by exec's stderr-copier goroutine and read only after cmd.Wait
// returns (ordered via the exited channel), so no locking is needed.
type cappedLog struct {
	buf    []byte
	maxLen int
}

func (c *cappedLog) Write(p []byte) (int, error) {
	c.buf = append(c.buf, p...)
	if over := len(c.buf) - c.maxLen; over > 0 {
		c.buf = c.buf[over:]
	}
	return len(p), nil
}

// startServer spawns llama-server for modelPath and blocks until its
// /health endpoint reports ready, with a 60s overall deadline.
// gpuLayers < 0 means "offload all layers"; ctxSize <= 0 leaves the
// server's default context length.
func startServer(binary, modelPath string, port, gpuLayers, ctxSize int) (*server, error) {
	ngl := gpuLayers
	if ngl < 0 {
		ngl = 999 // all layers
	}
	args := []string{
		"--model", modelPath,
		"--host", "127.0.0.1",
		"--port", strconv.Itoa(port),
		"--n-gpu-layers", strconv.Itoa(ngl),
	}
	if ctxSize > 0 {
		args = append(args, "--ctx-size", strconv.Itoa(ctxSize))
	}
	cmd := exec.Command(binary, args...)
	cmd.Env = serverEnv() // forces HIP_VISIBLE_DEVICES=0; see serverEnv
	// Capture the tail of stderr so a startup crash is diagnosable
	// instead of reporting only a bare exit status.
	stderr := &cappedLog{maxLen: 8 * 1024}
	cmd.Stderr = stderr
	if err := cmd.Start(); err != nil {
		return nil, fmt.Errorf("rocm: start llama-server: %w", err)
	}
	client := llamacpp.NewClient(fmt.Sprintf("http://127.0.0.1:%d", port))
	s := &server{
		cmd:    cmd,
		port:   port,
		client: client,
		exited: make(chan struct{}),
	}
	go func() {
		err := cmd.Wait()
		// Wait returns only after stderr copying has finished, so the
		// buffer is complete and safe to read here.
		if err != nil && len(stderr.buf) > 0 {
			err = fmt.Errorf("%w; stderr: %s", err, strings.TrimSpace(string(stderr.buf)))
		}
		s.exitErr = err
		close(s.exited)
	}()
	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
	defer cancel()
	if err := s.waitReady(ctx); err != nil {
		s.stop()
		return nil, err
	}
	return s, nil
}
// waitReady polls the health endpoint every 100ms until the server
// reports ready, the context expires, or the process dies — whichever
// happens first.
func (s *server) waitReady(ctx context.Context) error {
	const pollInterval = 100 * time.Millisecond
	tick := time.NewTicker(pollInterval)
	defer tick.Stop()
	for {
		select {
		case <-s.exited:
			// Process death during startup beats a timeout report.
			return fmt.Errorf("rocm: llama-server exited during startup: %v", s.exitErr)
		case <-ctx.Done():
			return fmt.Errorf("rocm: llama-server health timeout: %w", ctx.Err())
		case <-tick.C:
			if s.client.Health(ctx) == nil {
				return nil
			}
		}
	}
}
// stop sends SIGTERM and waits up to 5s for a clean exit, then SIGKILL.
// Safe to call after the process has already exited: the signal error
// is ignored and s.exited is already closed, so the select returns
// immediately. Always returns nil.
func (s *server) stop() error {
	if s.cmd.Process == nil {
		return nil
	}
	// Best effort: the process may already be gone.
	_ = s.cmd.Process.Signal(syscall.SIGTERM)
	select {
	case <-time.After(5 * time.Second):
		// Grace period expired — escalate and wait for Wait() to finish.
		_ = s.cmd.Process.Kill()
		<-s.exited
	case <-s.exited:
	}
	return nil
}
Step 4: Run tests to verify they pass
Run: go test -run "TestFind|TestFree|TestServer" -v
Expected: PASS (all tests). Note: TestFindLlamaServer_InPATH requires llama-server installed.
Step 5: Add testify dependency and commit
go get github.com/stretchr/testify
go mod tidy
git add server.go server_test.go go.mod go.sum
git commit -m "feat: server lifecycle and helpers for llama-server subprocess"
Task 5: TextModel Implementation
Files:
- Create:
model.go
No unit test for this file — it's a thin wrapper that maps inference types to llamacpp types. Tested via integration test in Task 7.
Step 1: Write the implementation
// model.go
//go:build linux && amd64
package rocm
import (
"context"
"iter"
"sync"
"forge.lthn.ai/core/go-inference"
"forge.lthn.ai/core/go-rocm/internal/llamacpp"
)
// rocmModel implements inference.TextModel using a llama-server subprocess.
type rocmModel struct {
	srv       *server // owns the llama-server process and its HTTP client
	modelType string  // architecture hint derived from the model filename
	mu        sync.Mutex
	lastErr   error // last streaming error; guarded by mu, exposed via Err
}

// Generate streams completion tokens for prompt via /v1/completions.
// Streaming errors are not delivered inline; they are recorded and
// must be retrieved with Err after the iterator finishes.
func (m *rocmModel) Generate(ctx context.Context, prompt string, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
	cfg := inference.ApplyGenerateOpts(opts)
	req := llamacpp.CompletionRequest{
		Prompt:        prompt,
		MaxTokens:     cfg.MaxTokens,
		Temperature:   cfg.Temperature,
		TopK:          cfg.TopK,
		TopP:          cfg.TopP,
		RepeatPenalty: cfg.RepeatPenalty,
	}
	chunks, errFn := m.srv.client.Complete(ctx, req)
	return func(yield func(inference.Token) bool) {
		for text := range chunks {
			if !yield(inference.Token{Text: text}) {
				break // consumer stopped early; still record any stream error
			}
		}
		if err := errFn(); err != nil {
			m.mu.Lock()
			m.lastErr = err
			m.mu.Unlock()
		}
	}
}

// Chat streams assistant tokens for the conversation via
// /v1/chat/completions. Error handling matches Generate: check Err
// after draining the iterator.
func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
	cfg := inference.ApplyGenerateOpts(opts)
	chatMsgs := make([]llamacpp.ChatMessage, len(messages))
	for i, msg := range messages {
		chatMsgs[i] = llamacpp.ChatMessage{Role: msg.Role, Content: msg.Content}
	}
	req := llamacpp.ChatRequest{
		Messages:      chatMsgs,
		MaxTokens:     cfg.MaxTokens,
		Temperature:   cfg.Temperature,
		TopK:          cfg.TopK,
		TopP:          cfg.TopP,
		RepeatPenalty: cfg.RepeatPenalty,
	}
	chunks, errFn := m.srv.client.ChatComplete(ctx, req)
	return func(yield func(inference.Token) bool) {
		for text := range chunks {
			if !yield(inference.Token{Text: text}) {
				break // consumer stopped early; still record any stream error
			}
		}
		if err := errFn(); err != nil {
			m.mu.Lock()
			m.lastErr = err
			m.mu.Unlock()
		}
	}
}

// ModelType returns the architecture hint guessed from the model filename.
func (m *rocmModel) ModelType() string { return m.modelType }

// Err returns the most recent streaming error recorded by Generate or Chat.
func (m *rocmModel) Err() error {
	m.mu.Lock()
	defer m.mu.Unlock()
	return m.lastErr
}

// Close shuts down the underlying llama-server subprocess.
func (m *rocmModel) Close() error {
	return m.srv.stop()
}
Step 2: Verify it compiles
Run: go build ./...
Expected: No errors.
Step 3: Commit
git add model.go
git commit -m "feat: TextModel implementation wrapping llama-server"
Task 6: Backend Available() and LoadModel()
Files:
- Modify:
backend.go(replace stubs with real implementation)
Step 1: Update backend.go
Replace the entire file:
// backend.go
//go:build linux && amd64
package rocm
import (
"fmt"
"os"
"path/filepath"
"strings"
"forge.lthn.ai/core/go-inference"
)
// rocmBackend implements inference.Backend for AMD ROCm GPUs.
type rocmBackend struct{}

// Name returns the backend identifier.
func (b *rocmBackend) Name() string { return "rocm" }

// Available reports whether ROCm inference can run on this machine:
// the ROCm kernel driver node /dev/kfd must exist and a llama-server
// binary must be locatable (env override or PATH).
func (b *rocmBackend) Available() bool {
	if _, err := os.Stat("/dev/kfd"); err != nil {
		return false
	}
	if _, err := findLlamaServer(); err != nil {
		return false
	}
	return true
}

// LoadModel spawns a llama-server for the GGUF file at path and wraps
// it in an inference.TextModel. On success the server has already
// passed its health check and is ready for requests.
func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (inference.TextModel, error) {
	cfg := inference.ApplyLoadOpts(opts)
	binary, err := findLlamaServer()
	if err != nil {
		return nil, err
	}
	port, err := freePort()
	if err != nil {
		return nil, fmt.Errorf("rocm: find free port: %w", err)
	}
	srv, err := startServer(binary, path, port, cfg.GPULayers, cfg.ContextLen)
	if err != nil {
		return nil, err
	}
	return &rocmModel{
		srv:       srv,
		modelType: guessModelType(path),
	}, nil
}
// guessModelType extracts a model-architecture hint from the filename.
// Separators ("-", "_", ".") are stripped before matching so versioned
// names match their specific architecture.
// e.g. "LEK-Gemma3-4B-Q4_K_M.gguf" -> "gemma3",
// "Llama-3.1-8B-Q4_K_M.gguf" -> "llama3"
func guessModelType(path string) string {
	name := strings.ToLower(filepath.Base(path))
	// Without this, "llama-3.1" would fail the "llama3" substring check
	// and fall through to the generic "llama" match.
	name = strings.NewReplacer("-", "", "_", "", ".", "").Replace(name)
	// More specific names must precede their prefixes (gemma3 before gemma).
	for _, arch := range []string{"gemma3", "gemma2", "gemma", "qwen3", "qwen2", "qwen", "llama3", "llama", "mistral", "phi"} {
		if strings.Contains(name, arch) {
			return arch
		}
	}
	return "unknown"
}
Step 2: Verify it compiles
Run: go build ./...
Expected: No errors.
Step 3: Write a quick test for Available() and guessModelType()
Add to server_test.go:
// TestAvailable checks backend detection on ROCm hardware; skipped when
// /dev/kfd is absent. NOTE(review): if /dev/kfd exists but llama-server
// is missing, Available() returns false and this test fails rather
// than skipping — consider also skipping on findLlamaServer error.
func TestAvailable(t *testing.T) {
	b := &rocmBackend{}
	// On this machine (snider-linux): /dev/kfd exists, llama-server in PATH
	// This test will fail on machines without ROCm
	if _, err := os.Stat("/dev/kfd"); err != nil {
		t.Skip("no ROCm hardware")
	}
	assert.True(t, b.Available())
}

// TestGuessModelType covers the architecture-hint heuristics.
// NOTE(review): the "llama3" case passes only if guessModelType
// normalizes separators before matching — a plain substring check on
// "llama-3.1" yields "llama".
func TestGuessModelType(t *testing.T) {
	tests := []struct {
		path     string
		expected string
	}{
		{"/data/lem/gguf/LEK-Gemma3-4B-Q4_K_M.gguf", "gemma3"},
		{"/data/models/Qwen3-8B-Q4_K_M.gguf", "qwen3"},
		{"/data/models/Llama-3.1-8B-Q4_K_M.gguf", "llama3"},
		{"/data/models/Mistral-7B-v0.3-Q4_K_M.gguf", "mistral"},
		{"/data/models/random-model.gguf", "unknown"},
	}
	for _, tt := range tests {
		t.Run(tt.expected, func(t *testing.T) {
			assert.Equal(t, tt.expected, guessModelType(tt.path))
		})
	}
}
Step 4: Run tests
Run: go test -run "TestAvailable|TestGuessModel" -v
Expected: PASS
Step 5: Commit
git add backend.go server_test.go
git commit -m "feat: Backend Available() and LoadModel() with GPU detection"
Task 7: Integration Test
Files:
- Create:
rocm_integration_test.go
This test loads a real model on the GPU, generates tokens, and verifies the full pipeline.
Step 1: Write the integration test
// rocm_integration_test.go
//go:build rocm
package rocm
import (
"context"
"os"
"testing"
"time"
"forge.lthn.ai/core/go-inference"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testModel is the GGUF used by every integration test; it lives on an
// SMB mount that may be unavailable.
const testModel = "/data/lem/gguf/LEK-Gemma3-1B-layered-v2-Q5_K_M.gguf"

// skipIfNoModel skips the test when the model file is absent.
func skipIfNoModel(t *testing.T) {
	t.Helper()
	if _, err := os.Stat(testModel); os.IsNotExist(err) {
		t.Skip("test model not available (SMB mount down?)")
	}
}

// skipIfNoROCm skips the test when the ROCm driver node or the
// llama-server binary is missing.
func skipIfNoROCm(t *testing.T) {
	t.Helper()
	if _, err := os.Stat("/dev/kfd"); err != nil {
		t.Skip("no ROCm hardware")
	}
	if _, err := findLlamaServer(); err != nil {
		t.Skip("llama-server not found")
	}
}
// TestROCm_LoadAndGenerate exercises the full pipeline: backend
// detection, model load (server spawn + health check), streaming
// generation, and shutdown.
func TestROCm_LoadAndGenerate(t *testing.T) {
	skipIfNoROCm(t)
	skipIfNoModel(t)
	b := &rocmBackend{}
	require.True(t, b.Available())
	m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
	require.NoError(t, err)
	defer m.Close()
	assert.Equal(t, "gemma3", m.ModelType())
	// Generate with a simple prompt
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	var tokens []string
	for tok := range m.Generate(ctx, "The capital of France is", inference.WithMaxTokens(16)) {
		tokens = append(tokens, tok.Text)
	}
	// Err must be checked after the iterator is drained.
	require.NoError(t, m.Err())
	require.NotEmpty(t, tokens, "expected at least one token")
	full := ""
	for _, tok := range tokens {
		full += tok
	}
	t.Logf("Generated: %s", full)
}

// TestROCm_Chat exercises the chat endpoint through the same pipeline.
func TestROCm_Chat(t *testing.T) {
	skipIfNoROCm(t)
	skipIfNoModel(t)
	b := &rocmBackend{}
	m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
	require.NoError(t, err)
	defer m.Close()
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	messages := []inference.Message{
		{Role: "user", Content: "Say hello in exactly three words."},
	}
	var tokens []string
	for tok := range m.Chat(ctx, messages, inference.WithMaxTokens(32)) {
		tokens = append(tokens, tok.Text)
	}
	require.NoError(t, m.Err())
	require.NotEmpty(t, tokens, "expected at least one token")
	full := ""
	for _, tok := range tokens {
		full += tok
	}
	t.Logf("Chat response: %s", full)
}

// TestROCm_ContextCancellation cancels generation mid-stream.
// NOTE(review): a few already-buffered tokens may still arrive after
// cancel(), so only a lower bound on the count is asserted.
func TestROCm_ContextCancellation(t *testing.T) {
	skipIfNoROCm(t)
	skipIfNoModel(t)
	b := &rocmBackend{}
	m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
	require.NoError(t, err)
	defer m.Close()
	ctx, cancel := context.WithCancel(context.Background())
	var count int
	for tok := range m.Generate(ctx, "Write a very long story about dragons", inference.WithMaxTokens(256)) {
		_ = tok
		count++
		if count >= 3 {
			cancel()
		}
	}
	t.Logf("Got %d tokens before cancel", count)
	assert.GreaterOrEqual(t, count, 3)
}
Step 2: Run integration tests
Run: go test -tags rocm -run TestROCm -v -timeout 120s
Expected: PASS (all 3 tests). Each test loads the model (~2-3s for 1B), generates tokens, and shuts down.
Step 3: Commit
git add rocm_integration_test.go
git commit -m "test: integration tests for full ROCm inference pipeline"
Task 8: Update TODO.md
Files:
- Modify:
TODO.md
Step 1: Mark Phase 1 tasks complete
Mark all 5 Phase 1 items as [x] with commit references and date.
Step 2: Commit
git add TODO.md
git commit -m "docs: mark Phase 1 tasks complete"
Build Order Summary
Task 1: health.go + test (unit, no GPU)
Task 2: client.go + test (unit, no GPU)
Task 3: client_test.go additions (unit, no GPU)
Task 4: server.go + test (unit, no GPU — helpers only)
Task 5: model.go (compiles)
Task 6: backend.go + test (unit, partial)
Task 7: integration test (GPU + model required)
Task 8: docs update
Tasks 1-6 can be verified without GPU hardware. Task 7 requires the RX 7800 XT and test model on SMB mount.