go-rocm/docs/plans/completed/2026-02-19-phase1-implementation.md
Claude c0b7485129
Some checks failed
Security Scan / security (push) Successful in 10s
Test / Vet & Build (push) Failing after 29s
docs: archive completed phase 1-4 plans
All four phases fully implemented and tested on RX 7800 XT.

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-24 19:42:07 +00:00

30 KiB

Phase 1: Core Implementation Plan

For Claude: REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.

Goal: Implement the full inference.Backend and inference.TextModel interfaces for AMD ROCm GPUs using llama-server as the subprocess engine.

Architecture: Bottom-up layered design. internal/llamacpp/ is a pure HTTP+SSE client with no go-inference dependency. server.go manages the llama-server process lifecycle. model.go implements inference.TextModel by wrapping the server and client. backend.go implements inference.Backend by wiring everything together.

Tech Stack: Go 1.25, llama-server (llama.cpp with HIP/ROCm), OpenAI-compatible API, SSE streaming, testify for assertions, httptest for mocking.

Critical environment note: HIP_VISIBLE_DEVICES=0 MUST be set when spawning llama-server. The Ryzen 9 9950X iGPU (Device 1) reports 100GB free memory (system RAM) and crashes llama-server when it tries to split tensors across devices.


Task 1: Health Check Client

Files:

  • Create: internal/llamacpp/health.go
  • Create: internal/llamacpp/health_test.go

Step 1: Create directory and write the failing test

Create internal/llamacpp/ directory, then write the test:

// internal/llamacpp/health_test.go
package llamacpp

import (
	"context"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestHealth_OK verifies that a 200 response with status "ok" returns no
// error and that the client calls the /health endpoint.
func TestHealth_OK(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, "/health", r.URL.Path)
		w.Header().Set("Content-Type", "application/json")
		w.Write([]byte(`{"status":"ok"}`))
	}))
	defer ts.Close()

	c := NewClient(ts.URL)
	err := c.Health(context.Background())
	require.NoError(t, err)
}

// TestHealth_NotReady verifies that a 200 response whose status is anything
// other than "ok" (llama-server reports "loading model" during startup)
// surfaces a "not ready" error.
func TestHealth_NotReady(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.Write([]byte(`{"status":"loading model"}`))
	}))
	defer ts.Close()

	c := NewClient(ts.URL)
	err := c.Health(context.Background())
	assert.ErrorContains(t, err, "not ready")
}

// TestHealth_ServerDown verifies that a transport-level connection failure
// is propagated as an error.
func TestHealth_ServerDown(t *testing.T) {
	c := NewClient("http://127.0.0.1:1") // nothing listening
	err := c.Health(context.Background())
	assert.Error(t, err)
}

Step 2: Run test to verify it fails

Run: go test ./internal/llamacpp/ -run TestHealth -v Expected: FAIL — NewClient and Health not defined.

Step 3: Write minimal implementation

// internal/llamacpp/health.go
package llamacpp

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"strings"
)

// Client communicates with a llama-server instance over its
// OpenAI-compatible HTTP API. Construct via NewClient; the zero value has
// no base URL and is not usable.
type Client struct {
	// baseURL is the server root with any trailing slash removed.
	baseURL    string
	// httpClient deliberately has no Timeout: SSE completions can stream for
	// a long time, so cancellation is driven by the per-request context.
	httpClient *http.Client
}

// NewClient creates a client for the llama-server at the given base URL.
// A trailing slash on baseURL is stripped so path joins are unambiguous.
func NewClient(baseURL string) *Client {
	trimmed := strings.TrimRight(baseURL, "/")
	c := &Client{baseURL: trimmed, httpClient: &http.Client{}}
	return c
}

// healthResponse mirrors the JSON body of llama-server's GET /health reply.
// Status is "ok" once the model is loaded; during startup the server
// reports other values such as "loading model".
type healthResponse struct {
	Status string `json:"status"`
}

// Health checks whether the llama-server is ready to accept requests.
// It returns nil only when GET /health responds 200 with status "ok"; a
// transport failure, non-200 status, undecodable body, or any other status
// string yields an error.
func (c *Client) Health(ctx context.Context) error {
	url := c.baseURL + "/health"
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return err
	}

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("llamacpp: health returned %d", resp.StatusCode)
	}

	var body healthResponse
	if decodeErr := json.NewDecoder(resp.Body).Decode(&body); decodeErr != nil {
		return fmt.Errorf("llamacpp: health decode: %w", decodeErr)
	}
	if body.Status == "ok" {
		return nil
	}
	return fmt.Errorf("llamacpp: server not ready (status: %s)", body.Status)
}

Step 4: Run test to verify it passes

Run: go test ./internal/llamacpp/ -run TestHealth -v Expected: PASS (all 3 tests)

Step 5: Commit

git add internal/llamacpp/health.go internal/llamacpp/health_test.go
git commit -m "feat: llamacpp health check client"

Task 2: SSE Streaming + Chat Completions

Files:

  • Create: internal/llamacpp/client.go
  • Create: internal/llamacpp/client_test.go

Step 1: Write the failing tests

// internal/llamacpp/client_test.go
package llamacpp

import (
	"context"
	"fmt"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestChatComplete_Streaming verifies that SSE delta chunks are yielded in
// order, that the [DONE] sentinel ends the stream, and that empty deltas
// (the finish_reason chunk) are not yielded.
func TestChatComplete_Streaming(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, "/v1/chat/completions", r.URL.Path)
		assert.Equal(t, "application/json", r.Header.Get("Content-Type"))

		flusher, ok := w.(http.Flusher)
		require.True(t, ok)

		w.Header().Set("Content-Type", "text/event-stream")
		w.WriteHeader(http.StatusOK)

		chunks := []string{
			`{"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}`,
			`{"choices":[{"delta":{"content":" world"},"finish_reason":null}]}`,
			`{"choices":[{"delta":{},"finish_reason":"stop"}]}`,
		}
		for _, c := range chunks {
			fmt.Fprintf(w, "data: %s\n\n", c)
			flusher.Flush()
		}
		fmt.Fprint(w, "data: [DONE]\n\n")
		flusher.Flush()
	}))
	defer ts.Close()

	c := NewClient(ts.URL)
	chunks, errFn := c.ChatComplete(context.Background(), ChatRequest{
		Messages:  []ChatMessage{{Role: "user", Content: "Hi"}},
		MaxTokens: 32,
	})

	var result []string
	for text := range chunks {
		result = append(result, text)
	}

	require.NoError(t, errFn())
	assert.Equal(t, []string{"Hello", " world"}, result)
}

// TestChatComplete_EmptyResponse verifies that a stream containing only the
// [DONE] sentinel yields no chunks and reports no error.
func TestChatComplete_EmptyResponse(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		fmt.Fprint(w, "data: [DONE]\n\n")
	}))
	defer ts.Close()

	c := NewClient(ts.URL)
	chunks, errFn := c.ChatComplete(context.Background(), ChatRequest{
		Messages:  []ChatMessage{{Role: "user", Content: "Hi"}},
		MaxTokens: 1,
	})

	var result []string
	for text := range chunks {
		result = append(result, text)
	}

	require.NoError(t, errFn())
	assert.Empty(t, result)
}

// TestChatComplete_HTTPError verifies that a non-200 response yields an
// empty iterator and an error message carrying the HTTP status code.
func TestChatComplete_HTTPError(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
	}))
	defer ts.Close()

	c := NewClient(ts.URL)
	chunks, errFn := c.ChatComplete(context.Background(), ChatRequest{
		Messages: []ChatMessage{{Role: "user", Content: "Hi"}},
	})

	for range chunks {
		t.Fatal("expected no chunks")
	}

	assert.ErrorContains(t, errFn(), "500")
}

// TestChatComplete_ContextCancelled verifies that cancelling the context
// mid-stream stops iteration after the chunk(s) already received, with the
// server handler blocking until the client disconnects.
func TestChatComplete_ContextCancelled(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		flusher, _ := w.(http.Flusher)
		w.Header().Set("Content-Type", "text/event-stream")
		w.WriteHeader(http.StatusOK)

		// Send one chunk then block (context will cancel)
		fmt.Fprint(w, "data: {\"choices\":[{\"delta\":{\"content\":\"Hi\"},\"finish_reason\":null}]}\n\n")
		flusher.Flush()

		// Block until client disconnects
		<-r.Context().Done()
	}))
	defer ts.Close()

	ctx, cancel := context.WithCancel(context.Background())
	c := NewClient(ts.URL)
	chunks, _ := c.ChatComplete(ctx, ChatRequest{
		Messages: []ChatMessage{{Role: "user", Content: "Hi"}},
	})

	var got int
	for range chunks {
		got++
		cancel() // Cancel after first chunk
	}
	assert.Equal(t, 1, got)
}

Step 2: Run tests to verify they fail

Run: go test ./internal/llamacpp/ -run TestChat -v Expected: FAIL — ChatComplete, ChatRequest, ChatMessage not defined.

Step 3: Write the implementation

// internal/llamacpp/client.go
package llamacpp

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"iter"
	"net/http"
	"strings"
)

// ChatMessage is a single message in a conversation.
// Role follows the OpenAI convention (e.g. "user", "assistant").
type ChatMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// ChatRequest is the request body for /v1/chat/completions.
// Stream is forced to true by ChatComplete before sending.
// NOTE(review): Temperature has no omitempty, so an unset (zero) value is
// serialized as 0 and overrides the server's default sampling temperature —
// confirm this is intentional.
type ChatRequest struct {
	Messages      []ChatMessage `json:"messages"`
	MaxTokens     int           `json:"max_tokens,omitempty"`
	Temperature   float32       `json:"temperature"`
	TopK          int           `json:"top_k,omitempty"`
	TopP          float32       `json:"top_p,omitempty"`
	RepeatPenalty float32       `json:"repeat_penalty,omitempty"`
	Stream        bool          `json:"stream"`
}

// CompletionRequest is the request body for /v1/completions.
// Stream is forced to true by Complete before sending.
// NOTE(review): Temperature has no omitempty — same caveat as ChatRequest.
type CompletionRequest struct {
	Prompt        string  `json:"prompt"`
	MaxTokens     int     `json:"max_tokens,omitempty"`
	Temperature   float32 `json:"temperature"`
	TopK          int     `json:"top_k,omitempty"`
	TopP          float32 `json:"top_p,omitempty"`
	RepeatPenalty float32 `json:"repeat_penalty,omitempty"`
	Stream        bool    `json:"stream"`
}

// chatChunkResponse is one streamed SSE chunk from /v1/chat/completions.
// FinishReason is non-nil (e.g. "stop") on the terminal chunk.
type chatChunkResponse struct {
	Choices []struct {
		Delta struct {
			Content string `json:"content"`
		} `json:"delta"`
		FinishReason *string `json:"finish_reason"`
	} `json:"choices"`
}

// completionChunkResponse is one streamed SSE chunk from /v1/completions.
type completionChunkResponse struct {
	Choices []struct {
		Text         string  `json:"text"`
		FinishReason *string `json:"finish_reason"`
	} `json:"choices"`
}

// ChatComplete sends a streaming chat completion request.
// Returns an iterator of text chunks and a function to retrieve any
// error that occurred during streaming. Call errFn after iteration; the
// response body is closed when the iterator finishes or is abandoned
// mid-loop, but only if the iterator is actually consumed.
func (c *Client) ChatComplete(ctx context.Context, req ChatRequest) (iter.Seq[string], func() error) {
	req.Stream = true

	payload, marshalErr := json.Marshal(req)
	if marshalErr != nil {
		wrapped := fmt.Errorf("llamacpp: marshal: %w", marshalErr)
		return noChunks, func() error { return wrapped }
	}

	url := c.baseURL + "/v1/chat/completions"
	httpReq, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload))
	if reqErr != nil {
		return noChunks, func() error { return reqErr }
	}
	httpReq.Header.Set("Content-Type", "application/json")

	resp, doErr := c.httpClient.Do(httpReq)
	if doErr != nil {
		return noChunks, func() error { return doErr }
	}
	if resp.StatusCode != http.StatusOK {
		resp.Body.Close()
		statusErr := fmt.Errorf("llamacpp: chat request returned %d", resp.StatusCode)
		return noChunks, func() error { return statusErr }
	}

	var streamErr error
	stream := func(yield func(string) bool) {
		defer resp.Body.Close()
		for data := range parseSSE(resp.Body, &streamErr) {
			var chunk chatChunkResponse
			if err := json.Unmarshal([]byte(data), &chunk); err != nil {
				streamErr = fmt.Errorf("llamacpp: parse chunk: %w", err)
				return
			}
			if len(chunk.Choices) == 0 {
				continue
			}
			if text := chunk.Choices[0].Delta.Content; text != "" && !yield(text) {
				return
			}
		}
	}
	return stream, func() error { return streamErr }
}

// Complete sends a streaming text completion request.
// Returns an iterator of text chunks and a function to retrieve any
// error that occurred during streaming. Call errFn after iteration; the
// response body is closed when the iterator finishes, but only if the
// iterator is actually consumed.
func (c *Client) Complete(ctx context.Context, req CompletionRequest) (iter.Seq[string], func() error) {
	req.Stream = true

	payload, marshalErr := json.Marshal(req)
	if marshalErr != nil {
		wrapped := fmt.Errorf("llamacpp: marshal: %w", marshalErr)
		return noChunks, func() error { return wrapped }
	}

	url := c.baseURL + "/v1/completions"
	httpReq, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload))
	if reqErr != nil {
		return noChunks, func() error { return reqErr }
	}
	httpReq.Header.Set("Content-Type", "application/json")

	resp, doErr := c.httpClient.Do(httpReq)
	if doErr != nil {
		return noChunks, func() error { return doErr }
	}
	if resp.StatusCode != http.StatusOK {
		resp.Body.Close()
		statusErr := fmt.Errorf("llamacpp: completion request returned %d", resp.StatusCode)
		return noChunks, func() error { return statusErr }
	}

	var streamErr error
	stream := func(yield func(string) bool) {
		defer resp.Body.Close()
		for data := range parseSSE(resp.Body, &streamErr) {
			var chunk completionChunkResponse
			if err := json.Unmarshal([]byte(data), &chunk); err != nil {
				streamErr = fmt.Errorf("llamacpp: parse chunk: %w", err)
				return
			}
			if len(chunk.Choices) == 0 {
				continue
			}
			if text := chunk.Choices[0].Text; text != "" && !yield(text) {
				return
			}
		}
	}
	return stream, func() error { return streamErr }
}

// parseSSE reads SSE data lines from r, yielding each data payload.
// Stops on "data: [DONE]" or reader exhaustion. Sets *errOut on read error.
func parseSSE(r io.Reader, errOut *error) iter.Seq[string] {
	return func(yield func(string) bool) {
		scanner := bufio.NewScanner(r)
		for scanner.Scan() {
			line := scanner.Text()
			if !strings.HasPrefix(line, "data: ") {
				continue
			}
			payload := strings.TrimPrefix(line, "data: ")
			if payload == "[DONE]" {
				return
			}
			if !yield(payload) {
				return
			}
		}
		if err := scanner.Err(); err != nil {
			*errOut = fmt.Errorf("llamacpp: read stream: %w", err)
		}
	}
}

// noChunks is an empty iterator returned on the error paths of ChatComplete
// and Complete, so callers can always range over the result safely.
func noChunks(func(string) bool) {}

Step 4: Run tests to verify they pass

Run: go test ./internal/llamacpp/ -run TestChat -v Expected: PASS (all 4 tests)

Step 5: Commit

git add internal/llamacpp/client.go internal/llamacpp/client_test.go
git commit -m "feat: llamacpp SSE streaming client for chat completions"

Task 3: Text Completion Streaming Tests

Files:

  • Modify: internal/llamacpp/client_test.go (add completion tests)

The Complete method was already implemented in Task 2. This task adds its tests.

Step 1: Write the failing tests

Append to internal/llamacpp/client_test.go:

// TestComplete_Streaming verifies that /v1/completions SSE chunks are
// yielded in order and that empty-text terminal chunks are not yielded.
func TestComplete_Streaming(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, "/v1/completions", r.URL.Path)

		flusher, ok := w.(http.Flusher)
		require.True(t, ok)

		w.Header().Set("Content-Type", "text/event-stream")
		w.WriteHeader(http.StatusOK)

		chunks := []string{
			`{"choices":[{"text":"Once","finish_reason":null}]}`,
			`{"choices":[{"text":" upon","finish_reason":null}]}`,
			`{"choices":[{"text":" a time","finish_reason":null}]}`,
			`{"choices":[{"text":"","finish_reason":"stop"}]}`,
		}
		for _, c := range chunks {
			fmt.Fprintf(w, "data: %s\n\n", c)
			flusher.Flush()
		}
		fmt.Fprint(w, "data: [DONE]\n\n")
		flusher.Flush()
	}))
	defer ts.Close()

	c := NewClient(ts.URL)
	chunks, errFn := c.Complete(context.Background(), CompletionRequest{
		Prompt:    "Once",
		MaxTokens: 32,
	})

	var result []string
	for text := range chunks {
		result = append(result, text)
	}

	require.NoError(t, errFn())
	assert.Equal(t, []string{"Once", " upon", " a time"}, result)
}

// TestComplete_HTTPError verifies that a non-200 response yields an empty
// iterator and an error carrying the HTTP status code.
func TestComplete_HTTPError(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusBadRequest)
	}))
	defer ts.Close()

	c := NewClient(ts.URL)
	chunks, errFn := c.Complete(context.Background(), CompletionRequest{
		Prompt: "Hello",
	})

	for range chunks {
		t.Fatal("expected no chunks")
	}

	assert.ErrorContains(t, errFn(), "400")
}

Step 2: Run tests to verify they pass

Run: go test ./internal/llamacpp/ -run TestComplete -v Expected: PASS (both tests — implementation already exists from Task 2)

Step 3: Commit

git add internal/llamacpp/client_test.go
git commit -m "test: completion streaming tests for llamacpp client"

Task 4: Server Helpers

Files:

  • Create: server.go
  • Create: server_test.go

These are the utility functions that don't require a GPU: finding the llama-server binary and getting a free port.

Step 1: Write the failing tests

// server_test.go
//go:build linux && amd64

package rocm

import (
	"os"
	"strings"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestFindLlamaServer_InPATH verifies PATH lookup of the binary.
// NOTE(review): machine-dependent — fails on hosts without llama-server
// installed (acknowledged in the plan's Step 4 notes).
func TestFindLlamaServer_InPATH(t *testing.T) {
	// llama-server is installed at /usr/local/bin/llama-server on this machine
	path, err := findLlamaServer()
	require.NoError(t, err)
	assert.Contains(t, path, "llama-server")
}

// TestFindLlamaServer_EnvOverride verifies that ROCM_LLAMA_SERVER_PATH takes
// precedence over PATH lookup when it points at an existing file.
func TestFindLlamaServer_EnvOverride(t *testing.T) {
	t.Setenv("ROCM_LLAMA_SERVER_PATH", "/usr/local/bin/llama-server")
	path, err := findLlamaServer()
	require.NoError(t, err)
	assert.Equal(t, "/usr/local/bin/llama-server", path)
}

// TestFindLlamaServer_EnvNotFound verifies that a nonexistent override path
// is an error rather than falling back to PATH.
func TestFindLlamaServer_EnvNotFound(t *testing.T) {
	t.Setenv("ROCM_LLAMA_SERVER_PATH", "/nonexistent/llama-server")
	_, err := findLlamaServer()
	assert.ErrorContains(t, err, "not found")
}

// TestFreePort verifies that freePort returns a valid TCP port number.
func TestFreePort(t *testing.T) {
	port, err := freePort()
	require.NoError(t, err)
	assert.Greater(t, port, 0)
	assert.Less(t, port, 65536)
}

// TestFreePort_UniquePerCall documents that successive calls usually get
// distinct ports; uniqueness is not asserted because the kernel may reuse.
func TestFreePort_UniquePerCall(t *testing.T) {
	p1, err := freePort()
	require.NoError(t, err)
	p2, err := freePort()
	require.NoError(t, err)
	// Not guaranteed but extremely likely with random kernel allocation
	_ = p1
	_ = p2
}

// TestServerEnv_HIPVisibleDevices verifies the subprocess env pins
// HIP_VISIBLE_DEVICES=0 exactly once.
func TestServerEnv_HIPVisibleDevices(t *testing.T) {
	env := serverEnv()
	var hipVals []string
	for _, e := range env {
		if strings.HasPrefix(e, "HIP_VISIBLE_DEVICES=") {
			hipVals = append(hipVals, e)
		}
	}
	assert.Equal(t, []string{"HIP_VISIBLE_DEVICES=0"}, hipVals)
}

// TestServerEnv_FiltersExistingHIP verifies an inherited HIP_VISIBLE_DEVICES
// value is dropped, not duplicated alongside the override.
func TestServerEnv_FiltersExistingHIP(t *testing.T) {
	t.Setenv("HIP_VISIBLE_DEVICES", "1")
	env := serverEnv()
	var hipVals []string
	for _, e := range env {
		if strings.HasPrefix(e, "HIP_VISIBLE_DEVICES=") {
			hipVals = append(hipVals, e)
		}
	}
	// Must filter the old value and only have our override
	assert.Equal(t, []string{"HIP_VISIBLE_DEVICES=0"}, hipVals)
}

Step 2: Run tests to verify they fail

Run: go test -run "TestFind|TestFree|TestServer" -v Expected: FAIL — findLlamaServer, freePort, serverEnv not defined.

Step 3: Write minimal implementation

// server.go
//go:build linux && amd64

package rocm

import (
	"context"
	"fmt"
	"net"
	"os"
	"os/exec"
	"strconv"
	"strings"
	"syscall"
	"time"

	"forge.lthn.ai/core/go-rocm/internal/llamacpp"
)

// server manages a llama-server subprocess.
type server struct {
	cmd     *exec.Cmd        // the spawned llama-server process
	port    int              // localhost TCP port the server listens on
	client  *llamacpp.Client // HTTP client bound to this server instance
	exited  chan struct{}    // closed by the reaper goroutine when the process exits
	exitErr error            // result of cmd.Wait; read only after exited is closed
}

// findLlamaServer locates the llama-server binary.
// Checks ROCM_LLAMA_SERVER_PATH first, then PATH.
func findLlamaServer() (string, error) {
	if p := os.Getenv("ROCM_LLAMA_SERVER_PATH"); p != "" {
		if _, err := os.Stat(p); err == nil {
			return p, nil
		}
		return "", fmt.Errorf("rocm: ROCM_LLAMA_SERVER_PATH=%q not found", p)
	}
	p, err := exec.LookPath("llama-server")
	if err != nil {
		return "", fmt.Errorf("rocm: llama-server not in PATH: %w", err)
	}
	return p, nil
}

// freePort asks the kernel for a free TCP port on localhost by binding to
// port 0 and reading back the assigned port. The listener is released
// before the caller uses the port, so another process could in principle
// grab it in between (acceptable for local subprocess use).
func freePort() (int, error) {
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return 0, err
	}
	defer listener.Close()
	return listener.Addr().(*net.TCPAddr).Port, nil
}

// serverEnv returns the environment for the llama-server subprocess.
// Filters any existing HIP_VISIBLE_DEVICES and sets it to 0 to mask the iGPU.
func serverEnv() []string {
	env := os.Environ()
	filtered := make([]string, 0, len(env)+1)
	for _, e := range env {
		if !strings.HasPrefix(e, "HIP_VISIBLE_DEVICES=") {
			filtered = append(filtered, e)
		}
	}
	return append(filtered, "HIP_VISIBLE_DEVICES=0")
}

// startServer spawns llama-server for modelPath on the given localhost port
// and blocks until its health endpoint reports ready (60s budget) or the
// process exits. gpuLayers < 0 offloads all layers; ctxSize <= 0 keeps the
// server's default context size.
// NOTE(review): the child's stdout/stderr are not wired anywhere, so crash
// diagnostics from llama-server are lost — consider capturing cmd.Stderr.
func startServer(binary, modelPath string, port, gpuLayers, ctxSize int) (*server, error) {
	ngl := gpuLayers
	if ngl < 0 {
		ngl = 999 // all layers
	}

	args := []string{
		"--model", modelPath,
		"--host", "127.0.0.1",
		"--port", strconv.Itoa(port),
		"--n-gpu-layers", strconv.Itoa(ngl),
	}
	if ctxSize > 0 {
		args = append(args, "--ctx-size", strconv.Itoa(ctxSize))
	}

	cmd := exec.Command(binary, args...)
	// serverEnv pins HIP_VISIBLE_DEVICES=0 to mask the iGPU.
	cmd.Env = serverEnv()

	if err := cmd.Start(); err != nil {
		return nil, fmt.Errorf("rocm: start llama-server: %w", err)
	}

	client := llamacpp.NewClient(fmt.Sprintf("http://127.0.0.1:%d", port))
	s := &server{
		cmd:    cmd,
		port:   port,
		client: client,
		exited: make(chan struct{}),
	}

	// Reaper goroutine: exitErr is written before exited is closed, so any
	// reader that waits on the channel observes a consistent value.
	go func() {
		s.exitErr = cmd.Wait()
		close(s.exited)
	}()

	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
	defer cancel()
	if err := s.waitReady(ctx); err != nil {
		// Don't leak the subprocess on startup failure.
		s.stop()
		return nil, err
	}

	return s, nil
}

// waitReady polls the health endpoint every 100ms until the server reports
// ready, the context expires, or the subprocess exits.
// Each Health probe shares ctx (the caller's overall startup deadline), so
// a hung probe cannot outlive the startup budget.
func (s *server) waitReady(ctx context.Context) error {
	ticker := time.NewTicker(100 * time.Millisecond)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return fmt.Errorf("rocm: llama-server health timeout: %w", ctx.Err())
		case <-s.exited:
			// Process died during model load; exitErr is safe to read here
			// because it is written before exited is closed.
			return fmt.Errorf("rocm: llama-server exited during startup: %v", s.exitErr)
		case <-ticker.C:
			if err := s.client.Health(ctx); err == nil {
				return nil
			}
		}
	}
}

// stop sends SIGTERM and waits up to 5s, then SIGKILL.
// In both paths it waits for the reaper goroutine to observe process exit
// before returning, so no zombie is left behind. Always returns nil;
// shutdown is best-effort.
func (s *server) stop() error {
	if s.cmd.Process == nil {
		return nil
	}
	// Graceful first; signal errors are ignored (process may already be gone).
	_ = s.cmd.Process.Signal(syscall.SIGTERM)
	select {
	case <-time.After(5 * time.Second):
		// Escalate if the process has not exited within the grace period.
		_ = s.cmd.Process.Kill()
		<-s.exited
	case <-s.exited:
	}
	return nil
}

Step 4: Run tests to verify they pass

Run: go test -run "TestFind|TestFree|TestServer" -v Expected: PASS (all tests). Note: TestFindLlamaServer_InPATH requires llama-server installed.

Step 5: Add testify dependency and commit

go get github.com/stretchr/testify
go mod tidy
git add server.go server_test.go go.mod go.sum
git commit -m "feat: server lifecycle and helpers for llama-server subprocess"

Task 5: TextModel Implementation

Files:

  • Create: model.go

No unit test for this file — it's a thin wrapper that maps inference types to llamacpp types. Tested via integration test in Task 7.

Step 1: Write the implementation

// model.go
//go:build linux && amd64

package rocm

import (
	"context"
	"iter"
	"sync"

	"forge.lthn.ai/core/go-inference"
	"forge.lthn.ai/core/go-rocm/internal/llamacpp"
)

// rocmModel implements inference.TextModel using a llama-server subprocess.
type rocmModel struct {
	srv       *server // owning llama-server; stopped by Close
	modelType string  // architecture hint derived from the model filename

	// mu guards lastErr, which records the most recent streaming failure
	// from Generate/Chat for later retrieval via Err.
	mu      sync.Mutex
	lastErr error
}

// Generate streams completion tokens for prompt. A streaming error, if any,
// is recorded after iteration and can be retrieved via Err.
func (m *rocmModel) Generate(ctx context.Context, prompt string, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
	params := inference.ApplyGenerateOpts(opts)

	chunks, errFn := m.srv.client.Complete(ctx, llamacpp.CompletionRequest{
		Prompt:        prompt,
		MaxTokens:     params.MaxTokens,
		Temperature:   params.Temperature,
		TopK:          params.TopK,
		TopP:          params.TopP,
		RepeatPenalty: params.RepeatPenalty,
	})

	return func(yield func(inference.Token) bool) {
		for text := range chunks {
			if !yield(inference.Token{Text: text}) {
				break
			}
		}
		err := errFn()
		if err == nil {
			return
		}
		m.mu.Lock()
		m.lastErr = err
		m.mu.Unlock()
	}
}

// Chat streams assistant tokens for the given conversation. A streaming
// error, if any, is recorded after iteration and can be retrieved via Err.
func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
	params := inference.ApplyGenerateOpts(opts)

	converted := make([]llamacpp.ChatMessage, 0, len(messages))
	for _, msg := range messages {
		converted = append(converted, llamacpp.ChatMessage{Role: msg.Role, Content: msg.Content})
	}

	chunks, errFn := m.srv.client.ChatComplete(ctx, llamacpp.ChatRequest{
		Messages:      converted,
		MaxTokens:     params.MaxTokens,
		Temperature:   params.Temperature,
		TopK:          params.TopK,
		TopP:          params.TopP,
		RepeatPenalty: params.RepeatPenalty,
	})

	return func(yield func(inference.Token) bool) {
		for text := range chunks {
			if !yield(inference.Token{Text: text}) {
				break
			}
		}
		err := errFn()
		if err == nil {
			return
		}
		m.mu.Lock()
		m.lastErr = err
		m.mu.Unlock()
	}
}

// ModelType reports the architecture hint derived from the model filename.
func (m *rocmModel) ModelType() string { return m.modelType }

// Err returns the last streaming error recorded by Generate or Chat, or nil.
func (m *rocmModel) Err() error {
	m.mu.Lock()
	defer m.mu.Unlock()
	return m.lastErr
}

// Close shuts down the underlying llama-server subprocess.
func (m *rocmModel) Close() error {
	return m.srv.stop()
}

Step 2: Verify it compiles

Run: go build ./... Expected: No errors.

Step 3: Commit

git add model.go
git commit -m "feat: TextModel implementation wrapping llama-server"

Task 6: Backend Available() and LoadModel()

Files:

  • Modify: backend.go (replace stubs with real implementation)

Step 1: Update backend.go

Replace the entire file:

// backend.go
//go:build linux && amd64

package rocm

import (
	"fmt"
	"os"
	"path/filepath"
	"strings"

	"forge.lthn.ai/core/go-inference"
)

// rocmBackend implements inference.Backend for AMD ROCm GPUs.
// It is stateless; each LoadModel call owns its own subprocess.
type rocmBackend struct{}

// Name identifies this backend as "rocm".
func (b *rocmBackend) Name() string { return "rocm" }

// Available reports whether ROCm inference can run on this machine:
// the /dev/kfd device node must exist and a llama-server binary must be
// locatable (env override or PATH).
func (b *rocmBackend) Available() bool {
	if _, kfdErr := os.Stat("/dev/kfd"); kfdErr != nil {
		return false
	}
	_, binErr := findLlamaServer()
	return binErr == nil
}

// LoadModel starts a dedicated llama-server subprocess for the GGUF model
// at path and returns a TextModel backed by it. Closing the returned model
// stops the subprocess.
func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (inference.TextModel, error) {
	cfg := inference.ApplyLoadOpts(opts)

	binary, err := findLlamaServer()
	if err != nil {
		return nil, err
	}

	// NOTE(review): the port is released before llama-server binds it, so a
	// concurrent process could race for it — acceptable for local use.
	port, err := freePort()
	if err != nil {
		return nil, fmt.Errorf("rocm: find free port: %w", err)
	}

	// Blocks until the server's health endpoint reports ready (or times out).
	srv, err := startServer(binary, path, port, cfg.GPULayers, cfg.ContextLen)
	if err != nil {
		return nil, err
	}

	return &rocmModel{
		srv:       srv,
		modelType: guessModelType(path),
	}, nil
}

// guessModelType extracts a model type hint from the filename.
// e.g. "LEK-Gemma3-4B-Q4_K_M.gguf" -> "gemma3"
// Separator characters are stripped before matching so versioned names like
// "Llama-3.1-8B" match "llama3" (the raw lowercase name contains "llama-3",
// not "llama3", and would otherwise fall through to the less specific
// "llama"). Returns "unknown" when no known architecture is found.
func guessModelType(path string) string {
	name := strings.ToLower(filepath.Base(path))
	// Drop common separators: '-', '_', '.', ' '.
	name = strings.Map(func(r rune) rune {
		switch r {
		case '-', '_', '.', ' ':
			return -1
		}
		return r
	}, name)
	// Ordered most-specific first so e.g. "gemma3" wins over "gemma".
	for _, arch := range []string{"gemma3", "gemma2", "gemma", "qwen3", "qwen2", "qwen", "llama3", "llama", "mistral", "phi"} {
		if strings.Contains(name, arch) {
			return arch
		}
	}
	return "unknown"
}

Step 2: Verify it compiles

Run: go build ./... Expected: No errors.

Step 3: Write a quick test for Available() and guessModelType()

Add to server_test.go:

// TestAvailable checks backend detection; skipped on machines without
// a /dev/kfd device node.
func TestAvailable(t *testing.T) {
	b := &rocmBackend{}
	// On this machine (snider-linux): /dev/kfd exists, llama-server in PATH
	// This test will fail on machines without ROCm
	if _, err := os.Stat("/dev/kfd"); err != nil {
		t.Skip("no ROCm hardware")
	}
	assert.True(t, b.Available())
}

// TestGuessModelType covers architecture detection from GGUF filenames.
// NOTE(review): with guessModelType as written in this plan, the
// "Llama-3.1-8B" case FAILS — the lowercase name contains "llama-3", not
// "llama3", so the loop matches the less specific "llama". The
// implementation must normalize separators (e.g. strip '-', '.', '_')
// before matching for this expectation to hold.
func TestGuessModelType(t *testing.T) {
	tests := []struct {
		path     string
		expected string
	}{
		{"/data/lem/gguf/LEK-Gemma3-4B-Q4_K_M.gguf", "gemma3"},
		{"/data/models/Qwen3-8B-Q4_K_M.gguf", "qwen3"},
		{"/data/models/Llama-3.1-8B-Q4_K_M.gguf", "llama3"},
		{"/data/models/Mistral-7B-v0.3-Q4_K_M.gguf", "mistral"},
		{"/data/models/random-model.gguf", "unknown"},
	}
	for _, tt := range tests {
		t.Run(tt.expected, func(t *testing.T) {
			assert.Equal(t, tt.expected, guessModelType(tt.path))
		})
	}
}

Step 4: Run tests

Run: go test -run "TestAvailable|TestGuessModel" -v Expected: PASS

Step 5: Commit

git add backend.go server_test.go
git commit -m "feat: Backend Available() and LoadModel() with GPU detection"

Task 7: Integration Test

Files:

  • Create: rocm_integration_test.go

This test loads a real model on the GPU, generates tokens, and verifies the full pipeline.

Step 1: Write the integration test

// rocm_integration_test.go
//go:build rocm

package rocm

import (
	"context"
	"os"
	"testing"
	"time"

	"forge.lthn.ai/core/go-inference"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// testModel is a small quantized GGUF on the local SMB mount; integration
// tests skip when it is absent.
const testModel = "/data/lem/gguf/LEK-Gemma3-1B-layered-v2-Q5_K_M.gguf"

// skipIfNoModel skips the test when the GGUF test model file is not present.
func skipIfNoModel(t *testing.T) {
	t.Helper()
	if _, err := os.Stat(testModel); os.IsNotExist(err) {
		t.Skip("test model not available (SMB mount down?)")
	}
}

// skipIfNoROCm skips the test when ROCm hardware (/dev/kfd) or the
// llama-server binary is unavailable.
func skipIfNoROCm(t *testing.T) {
	t.Helper()
	if _, err := os.Stat("/dev/kfd"); err != nil {
		t.Skip("no ROCm hardware")
	}
	if _, err := findLlamaServer(); err != nil {
		t.Skip("llama-server not found")
	}
}

// TestROCm_LoadAndGenerate exercises the full pipeline: backend detection,
// model load (spawns llama-server), streaming generation, and shutdown.
func TestROCm_LoadAndGenerate(t *testing.T) {
	skipIfNoROCm(t)
	skipIfNoModel(t)

	b := &rocmBackend{}
	require.True(t, b.Available())

	m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
	require.NoError(t, err)
	defer m.Close()

	assert.Equal(t, "gemma3", m.ModelType())

	// Generate with a simple prompt
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	var tokens []string
	for tok := range m.Generate(ctx, "The capital of France is", inference.WithMaxTokens(16)) {
		tokens = append(tokens, tok.Text)
	}

	require.NoError(t, m.Err())
	require.NotEmpty(t, tokens, "expected at least one token")

	full := ""
	for _, tok := range tokens {
		full += tok
	}
	t.Logf("Generated: %s", full)
}

// TestROCm_Chat exercises the chat path end-to-end against a real model.
func TestROCm_Chat(t *testing.T) {
	skipIfNoROCm(t)
	skipIfNoModel(t)

	b := &rocmBackend{}
	m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
	require.NoError(t, err)
	defer m.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	messages := []inference.Message{
		{Role: "user", Content: "Say hello in exactly three words."},
	}

	var tokens []string
	for tok := range m.Chat(ctx, messages, inference.WithMaxTokens(32)) {
		tokens = append(tokens, tok.Text)
	}

	require.NoError(t, m.Err())
	require.NotEmpty(t, tokens, "expected at least one token")

	full := ""
	for _, tok := range tokens {
		full += tok
	}
	t.Logf("Chat response: %s", full)
}

// TestROCm_ContextCancellation verifies that cancelling mid-generation
// terminates the stream. Note: tokens already in flight may still arrive
// after cancel, so only a lower bound on the count is asserted.
func TestROCm_ContextCancellation(t *testing.T) {
	skipIfNoROCm(t)
	skipIfNoModel(t)

	b := &rocmBackend{}
	m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
	require.NoError(t, err)
	defer m.Close()

	ctx, cancel := context.WithCancel(context.Background())

	var count int
	for tok := range m.Generate(ctx, "Write a very long story about dragons", inference.WithMaxTokens(256)) {
		_ = tok
		count++
		if count >= 3 {
			cancel()
		}
	}

	t.Logf("Got %d tokens before cancel", count)
	assert.GreaterOrEqual(t, count, 3)
}

Step 2: Run integration tests

Run: go test -tags rocm -run TestROCm -v -timeout 120s Expected: PASS (all 3 tests). Each test loads the model (~2-3s for 1B), generates tokens, and shuts down.

Step 3: Commit

git add rocm_integration_test.go
git commit -m "test: integration tests for full ROCm inference pipeline"

Task 8: Update TODO.md

Files:

  • Modify: TODO.md

Step 1: Mark Phase 1 tasks complete

Mark all 5 Phase 1 items as [x] with commit references and date.

Step 2: Commit

git add TODO.md
git commit -m "docs: mark Phase 1 tasks complete"

Build Order Summary

Task 1: health.go + test          (unit, no GPU)
Task 2: client.go + test          (unit, no GPU)
Task 3: client_test.go additions  (unit, no GPU)
Task 4: server.go + test          (unit, no GPU — helpers only)
Task 5: model.go                  (compiles)
Task 6: backend.go + test         (unit, partial)
Task 7: integration test          (GPU + model required)
Task 8: docs update

Tasks 1-6 can be verified without GPU hardware. Task 7 requires the RX 7800 XT and test model on SMB mount.