From def3167199c0ed6cc381efeb5d20464dc6ca0288 Mon Sep 17 00:00:00 2001
From: Claude
Date: Thu, 19 Feb 2026 20:58:46 +0000
Subject: [PATCH] feat: llamacpp SSE streaming client for chat completions
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add ChatComplete() and Complete() methods to the llamacpp Client,
backed by a shared parseSSE() line parser. Types include ChatMessage,
ChatRequest, CompletionRequest and their chunked response structs.
Tests cover multi-chunk streaming, empty responses, HTTP errors, and
context cancellation — all using httptest SSE servers.

Co-Authored-By: Virgil
Co-Authored-By: Claude Opus 4.6
---
 internal/llamacpp/client.go      | 199 +++++++++++++++++++++++++++++++
 internal/llamacpp/client_test.go | 141 ++++++++++++++++++++++
 2 files changed, 340 insertions(+)
 create mode 100644 internal/llamacpp/client.go
 create mode 100644 internal/llamacpp/client_test.go

diff --git a/internal/llamacpp/client.go b/internal/llamacpp/client.go
new file mode 100644
index 0000000..82263ed
--- /dev/null
+++ b/internal/llamacpp/client.go
@@ -0,0 +1,199 @@
+package llamacpp
+
+import (
+	"bufio"
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"iter"
+	"net/http"
+	"strings"
+)
+
+// ChatMessage is a single message in a conversation.
+type ChatMessage struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
+}
+
+// ChatRequest is the request body for /v1/chat/completions.
+type ChatRequest struct {
+	Messages      []ChatMessage `json:"messages"`
+	MaxTokens     int           `json:"max_tokens,omitempty"`
+	Temperature   float32       `json:"temperature"`
+	TopK          int           `json:"top_k,omitempty"`
+	TopP          float32       `json:"top_p,omitempty"`
+	RepeatPenalty float32       `json:"repeat_penalty,omitempty"`
+	Stream        bool          `json:"stream"`
+}
+
+// CompletionRequest is the request body for /v1/completions.
+type CompletionRequest struct {
+	Prompt        string  `json:"prompt"`
+	MaxTokens     int     `json:"max_tokens,omitempty"`
+	Temperature   float32 `json:"temperature"`
+	TopK          int     `json:"top_k,omitempty"`
+	TopP          float32 `json:"top_p,omitempty"`
+	RepeatPenalty float32 `json:"repeat_penalty,omitempty"`
+	Stream        bool    `json:"stream"`
+}
+
+type chatChunkResponse struct {
+	Choices []struct {
+		Delta struct {
+			Content string `json:"content"`
+		} `json:"delta"`
+		FinishReason *string `json:"finish_reason"`
+	} `json:"choices"`
+}
+
+type completionChunkResponse struct {
+	Choices []struct {
+		Text         string  `json:"text"`
+		FinishReason *string `json:"finish_reason"`
+	} `json:"choices"`
+}
+
+// ChatComplete sends a streaming chat completion request to /v1/chat/completions.
+// It returns an iterator over text chunks and a function that returns any error
+// that occurred during the request or while reading the stream.
+func (c *Client) ChatComplete(ctx context.Context, req ChatRequest) (iter.Seq[string], func() error) {
+	req.Stream = true
+
+	body, err := json.Marshal(req)
+	if err != nil {
+		return noChunks, func() error { return fmt.Errorf("llamacpp: marshal chat request: %w", err) }
+	}
+
+	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/chat/completions", bytes.NewReader(body))
+	if err != nil {
+		return noChunks, func() error { return fmt.Errorf("llamacpp: create chat request: %w", err) }
+	}
+	httpReq.Header.Set("Content-Type", "application/json")
+
+	resp, err := c.httpClient.Do(httpReq)
+	if err != nil {
+		return noChunks, func() error { return fmt.Errorf("llamacpp: chat request: %w", err) }
+	}
+
+	if resp.StatusCode != http.StatusOK {
+		defer resp.Body.Close()
+		respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 256))
+		return noChunks, func() error {
+			return fmt.Errorf("llamacpp: chat returned %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody)))
+		}
+	}
+
+	var streamErr error
+	sseData := parseSSE(resp.Body, &streamErr)
+
+	tokens := func(yield func(string) bool) {
+		defer resp.Body.Close()
+		for raw := range sseData {
+			var chunk chatChunkResponse
+			if err := json.Unmarshal([]byte(raw), &chunk); err != nil {
+				streamErr = fmt.Errorf("llamacpp: decode chat chunk: %w", err)
+				return
+			}
+			if len(chunk.Choices) == 0 {
+				continue
+			}
+			text := chunk.Choices[0].Delta.Content
+			if text == "" {
+				continue
+			}
+			if !yield(text) {
+				return
+			}
+		}
+	}
+
+	return tokens, func() error { return streamErr }
+}
+
+// Complete sends a streaming completion request to /v1/completions.
+// It returns an iterator over text chunks and a function that returns any error
+// that occurred during the request or while reading the stream.
+func (c *Client) Complete(ctx context.Context, req CompletionRequest) (iter.Seq[string], func() error) {
+	req.Stream = true
+
+	body, err := json.Marshal(req)
+	if err != nil {
+		return noChunks, func() error { return fmt.Errorf("llamacpp: marshal completion request: %w", err) }
+	}
+
+	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/completions", bytes.NewReader(body))
+	if err != nil {
+		return noChunks, func() error { return fmt.Errorf("llamacpp: create completion request: %w", err) }
+	}
+	httpReq.Header.Set("Content-Type", "application/json")
+
+	resp, err := c.httpClient.Do(httpReq)
+	if err != nil {
+		return noChunks, func() error { return fmt.Errorf("llamacpp: completion request: %w", err) }
+	}
+
+	if resp.StatusCode != http.StatusOK {
+		defer resp.Body.Close()
+		respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 256))
+		return noChunks, func() error {
+			return fmt.Errorf("llamacpp: completion returned %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody)))
+		}
+	}
+
+	var streamErr error
+	sseData := parseSSE(resp.Body, &streamErr)
+
+	tokens := func(yield func(string) bool) {
+		defer resp.Body.Close()
+		for raw := range sseData {
+			var chunk completionChunkResponse
+			if err := json.Unmarshal([]byte(raw), &chunk); err != nil {
+				streamErr = fmt.Errorf("llamacpp: decode completion chunk: %w", err)
+				return
+			}
+			if len(chunk.Choices) == 0 {
+				continue
+			}
+			text := chunk.Choices[0].Text
+			if text == "" {
+				continue
+			}
+			if !yield(text) {
+				return
+			}
+		}
+	}
+
+	return tokens, func() error { return streamErr }
+}
+
+// parseSSE reads SSE-formatted lines from r and yields the payload of each
+// "data: " line. It stops when it encounters "[DONE]" or an I/O error.
+// Any read error (other than EOF) is stored via errOut.
+func parseSSE(r io.Reader, errOut *error) iter.Seq[string] {
+	return func(yield func(string) bool) {
+		scanner := bufio.NewScanner(r)
+		for scanner.Scan() {
+			line := scanner.Text()
+			if !strings.HasPrefix(line, "data: ") {
+				continue
+			}
+			payload := strings.TrimPrefix(line, "data: ")
+			if payload == "[DONE]" {
+				return
+			}
+			if !yield(payload) {
+				return
+			}
+		}
+		if err := scanner.Err(); err != nil {
+			*errOut = fmt.Errorf("llamacpp: read SSE stream: %w", err)
+		}
+	}
+}
+
+// noChunks is an empty iterator returned when an error occurs before streaming begins.
+func noChunks(func(string) bool) {}
diff --git a/internal/llamacpp/client_test.go b/internal/llamacpp/client_test.go
new file mode 100644
index 0000000..73c68a9
--- /dev/null
+++ b/internal/llamacpp/client_test.go
@@ -0,0 +1,141 @@
+package llamacpp
+
+import (
+	"context"
+	"fmt"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// sseLines writes SSE-formatted lines to a flushing response writer.
+func sseLines(w http.ResponseWriter, lines []string) {
+	f, ok := w.(http.Flusher)
+	if !ok {
+		panic("ResponseWriter does not implement Flusher")
+	}
+	w.Header().Set("Content-Type", "text/event-stream")
+	w.Header().Set("Cache-Control", "no-cache")
+	w.WriteHeader(http.StatusOK)
+
+	for _, line := range lines {
+		fmt.Fprintf(w, "data: %s\n\n", line)
+		f.Flush()
+	}
+}
+
+func TestChatComplete_Streaming(t *testing.T) {
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		assert.Equal(t, "/v1/chat/completions", r.URL.Path)
+		assert.Equal(t, "POST", r.Method)
+		sseLines(w, []string{
+			`{"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}`,
+			`{"choices":[{"delta":{"content":" world"},"finish_reason":null}]}`,
+			`{"choices":[{"delta":{},"finish_reason":"stop"}]}`,
+			"[DONE]",
+		})
+	}))
+	defer ts.Close()
+
+	c := NewClient(ts.URL)
+	tokens, errFn := c.ChatComplete(context.Background(), ChatRequest{
+		Messages:    []ChatMessage{{Role: "user", Content: "Hi"}},
+		MaxTokens:   64,
+		Temperature: 0.0,
+		Stream:      true,
+	})
+
+	var got []string
+	for tok := range tokens {
+		got = append(got, tok)
+	}
+	require.NoError(t, errFn())
+	assert.Equal(t, []string{"Hello", " world"}, got)
+}
+
+func TestChatComplete_EmptyResponse(t *testing.T) {
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		sseLines(w, []string{"[DONE]"})
+	}))
+	defer ts.Close()
+
+	c := NewClient(ts.URL)
+	tokens, errFn := c.ChatComplete(context.Background(), ChatRequest{
+		Messages:    []ChatMessage{{Role: "user", Content: "Hi"}},
+		Temperature: 0.7,
+		Stream:      true,
+	})
+
+	var got []string
+	for tok := range tokens {
+		got = append(got, tok)
+	}
+	require.NoError(t, errFn())
+	assert.Empty(t, got)
+}
+
+func TestChatComplete_HTTPError(t *testing.T) {
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		http.Error(w, "internal server error", http.StatusInternalServerError)
+	}))
+	defer ts.Close()
+
+	c := NewClient(ts.URL)
+	tokens, errFn := c.ChatComplete(context.Background(), ChatRequest{
+		Messages:    []ChatMessage{{Role: "user", Content: "Hi"}},
+		Temperature: 0.7,
+		Stream:      true,
+	})
+
+	var got []string
+	for tok := range tokens {
+		got = append(got, tok)
+	}
+	assert.Empty(t, got)
+	err := errFn()
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "500")
+}
+
+func TestChatComplete_ContextCancelled(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		f, ok := w.(http.Flusher)
+		if !ok {
+			panic("ResponseWriter does not implement Flusher")
+		}
+		w.Header().Set("Content-Type", "text/event-stream")
+		w.Header().Set("Cache-Control", "no-cache")
+		w.WriteHeader(http.StatusOK)
+
+		// Send first chunk.
+		fmt.Fprintf(w, "data: %s\n\n", `{"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}`)
+		f.Flush()
+
+		// Wait for context cancellation before sending more.
+		<-r.Context().Done()
+	}))
+	defer ts.Close()
+
+	c := NewClient(ts.URL)
+	tokens, errFn := c.ChatComplete(ctx, ChatRequest{
+		Messages:    []ChatMessage{{Role: "user", Content: "Hi"}},
+		Temperature: 0.7,
+		Stream:      true,
+	})
+
+	var got []string
+	for tok := range tokens {
+		got = append(got, tok)
+		cancel() // Cancel after receiving the first token.
+	}
+	// The error may or may not be nil depending on timing;
+	// the important thing is we got exactly 1 token.
+	_ = errFn()
+	assert.Equal(t, []string{"Hello"}, got)
+}