feat: llamacpp SSE streaming client for chat completions
Add ChatComplete() and Complete() methods to the llamacpp Client, backed by a shared parseSSE() line parser. Types include ChatMessage, ChatRequest, CompletionRequest and their chunked response structs. Tests cover multi-chunk streaming, empty responses, HTTP errors, and context cancellation — all using httptest SSE servers. Co-Authored-By: Virgil <virgil@lethean.io> Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
d5a92c7212
commit
def3167199
2 changed files with 340 additions and 0 deletions
199
internal/llamacpp/client.go
Normal file
199
internal/llamacpp/client.go
Normal file
|
|
@ -0,0 +1,199 @@
|
|||
package llamacpp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"iter"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ChatMessage is a single message in a conversation.
type ChatMessage struct {
	// Role identifies the author of the message (e.g. "user", "assistant").
	Role string `json:"role"`
	// Content is the message text.
	Content string `json:"content"`
}
|
||||
|
||||
// ChatRequest is the request body for /v1/chat/completions.
type ChatRequest struct {
	// Messages is the conversation to complete.
	Messages []ChatMessage `json:"messages"`
	// MaxTokens caps generation length; omitted from the JSON when zero.
	MaxTokens int `json:"max_tokens,omitempty"`
	// Temperature is always serialized (no omitempty) so an explicit 0 reaches the server.
	Temperature float32 `json:"temperature"`
	// TopK sampling parameter; omitted when zero.
	TopK int `json:"top_k,omitempty"`
	// TopP (nucleus) sampling parameter; omitted when zero.
	TopP float32 `json:"top_p,omitempty"`
	// RepeatPenalty sampling parameter; omitted when zero.
	RepeatPenalty float32 `json:"repeat_penalty,omitempty"`
	// Stream requests SSE streaming; ChatComplete forces this to true.
	Stream bool `json:"stream"`
}
|
||||
|
||||
// CompletionRequest is the request body for /v1/completions.
type CompletionRequest struct {
	// Prompt is the raw text to complete.
	Prompt string `json:"prompt"`
	// MaxTokens caps generation length; omitted from the JSON when zero.
	MaxTokens int `json:"max_tokens,omitempty"`
	// Temperature is always serialized (no omitempty) so an explicit 0 reaches the server.
	Temperature float32 `json:"temperature"`
	// TopK sampling parameter; omitted when zero.
	TopK int `json:"top_k,omitempty"`
	// TopP (nucleus) sampling parameter; omitted when zero.
	TopP float32 `json:"top_p,omitempty"`
	// RepeatPenalty sampling parameter; omitted when zero.
	RepeatPenalty float32 `json:"repeat_penalty,omitempty"`
	// Stream requests SSE streaming; Complete forces this to true.
	Stream bool `json:"stream"`
}
|
||||
|
||||
// chatChunkResponse mirrors one streamed SSE chunk from /v1/chat/completions.
// Only the fields this client reads are declared.
type chatChunkResponse struct {
	Choices []struct {
		// Delta carries the incremental text for this chunk.
		Delta struct {
			Content string `json:"content"`
		} `json:"delta"`
		// FinishReason is non-nil on the final chunk (e.g. "stop").
		FinishReason *string `json:"finish_reason"`
	} `json:"choices"`
}
|
||||
|
||||
// completionChunkResponse mirrors one streamed SSE chunk from /v1/completions.
// Only the fields this client reads are declared.
type completionChunkResponse struct {
	Choices []struct {
		// Text carries the incremental text for this chunk.
		Text string `json:"text"`
		// FinishReason is non-nil on the final chunk (e.g. "stop").
		FinishReason *string `json:"finish_reason"`
	} `json:"choices"`
}
|
||||
|
||||
// ChatComplete sends a streaming chat completion request to /v1/chat/completions.
|
||||
// It returns an iterator over text chunks and a function that returns any error
|
||||
// that occurred during the request or while reading the stream.
|
||||
func (c *Client) ChatComplete(ctx context.Context, req ChatRequest) (iter.Seq[string], func() error) {
|
||||
req.Stream = true
|
||||
|
||||
body, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return noChunks, func() error { return fmt.Errorf("llamacpp: marshal chat request: %w", err) }
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/chat/completions", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return noChunks, func() error { return fmt.Errorf("llamacpp: create chat request: %w", err) }
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return noChunks, func() error { return fmt.Errorf("llamacpp: chat request: %w", err) }
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 256))
|
||||
return noChunks, func() error {
|
||||
return fmt.Errorf("llamacpp: chat returned %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody)))
|
||||
}
|
||||
}
|
||||
|
||||
var streamErr error
|
||||
sseData := parseSSE(resp.Body, &streamErr)
|
||||
|
||||
tokens := func(yield func(string) bool) {
|
||||
defer resp.Body.Close()
|
||||
for raw := range sseData {
|
||||
var chunk chatChunkResponse
|
||||
if err := json.Unmarshal([]byte(raw), &chunk); err != nil {
|
||||
streamErr = fmt.Errorf("llamacpp: decode chat chunk: %w", err)
|
||||
return
|
||||
}
|
||||
if len(chunk.Choices) == 0 {
|
||||
continue
|
||||
}
|
||||
text := chunk.Choices[0].Delta.Content
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
if !yield(text) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tokens, func() error { return streamErr }
|
||||
}
|
||||
|
||||
// Complete sends a streaming completion request to /v1/completions.
|
||||
// It returns an iterator over text chunks and a function that returns any error
|
||||
// that occurred during the request or while reading the stream.
|
||||
func (c *Client) Complete(ctx context.Context, req CompletionRequest) (iter.Seq[string], func() error) {
|
||||
req.Stream = true
|
||||
|
||||
body, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return noChunks, func() error { return fmt.Errorf("llamacpp: marshal completion request: %w", err) }
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/completions", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return noChunks, func() error { return fmt.Errorf("llamacpp: create completion request: %w", err) }
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return noChunks, func() error { return fmt.Errorf("llamacpp: completion request: %w", err) }
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 256))
|
||||
return noChunks, func() error {
|
||||
return fmt.Errorf("llamacpp: completion returned %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody)))
|
||||
}
|
||||
}
|
||||
|
||||
var streamErr error
|
||||
sseData := parseSSE(resp.Body, &streamErr)
|
||||
|
||||
tokens := func(yield func(string) bool) {
|
||||
defer resp.Body.Close()
|
||||
for raw := range sseData {
|
||||
var chunk completionChunkResponse
|
||||
if err := json.Unmarshal([]byte(raw), &chunk); err != nil {
|
||||
streamErr = fmt.Errorf("llamacpp: decode completion chunk: %w", err)
|
||||
return
|
||||
}
|
||||
if len(chunk.Choices) == 0 {
|
||||
continue
|
||||
}
|
||||
text := chunk.Choices[0].Text
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
if !yield(text) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tokens, func() error { return streamErr }
|
||||
}
|
||||
|
||||
// parseSSE reads SSE-formatted lines from r and yields the payload of each
|
||||
// "data: " line. It stops when it encounters "[DONE]" or an I/O error.
|
||||
// Any read error (other than EOF) is stored via errOut.
|
||||
func parseSSE(r io.Reader, errOut *error) iter.Seq[string] {
|
||||
return func(yield func(string) bool) {
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
payload := strings.TrimPrefix(line, "data: ")
|
||||
if payload == "[DONE]" {
|
||||
return
|
||||
}
|
||||
if !yield(payload) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
*errOut = fmt.Errorf("llamacpp: read SSE stream: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// noChunks is an empty iterator returned when an error occurs before streaming
// begins. It never calls yield and returns immediately.
func noChunks(_ func(string) bool) {}
|
||||
141
internal/llamacpp/client_test.go
Normal file
141
internal/llamacpp/client_test.go
Normal file
|
|
@ -0,0 +1,141 @@
|
|||
package llamacpp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// sseLines writes SSE-formatted lines to a flushing response writer.
|
||||
func sseLines(w http.ResponseWriter, lines []string) {
|
||||
f, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
panic("ResponseWriter does not implement Flusher")
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
for _, line := range lines {
|
||||
fmt.Fprintf(w, "data: %s\n\n", line)
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatComplete_Streaming(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/v1/chat/completions", r.URL.Path)
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
sseLines(w, []string{
|
||||
`{"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}`,
|
||||
`{"choices":[{"delta":{"content":" world"},"finish_reason":null}]}`,
|
||||
`{"choices":[{"delta":{},"finish_reason":"stop"}]}`,
|
||||
"[DONE]",
|
||||
})
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
c := NewClient(ts.URL)
|
||||
tokens, errFn := c.ChatComplete(context.Background(), ChatRequest{
|
||||
Messages: []ChatMessage{{Role: "user", Content: "Hi"}},
|
||||
MaxTokens: 64,
|
||||
Temperature: 0.0,
|
||||
Stream: true,
|
||||
})
|
||||
|
||||
var got []string
|
||||
for tok := range tokens {
|
||||
got = append(got, tok)
|
||||
}
|
||||
require.NoError(t, errFn())
|
||||
assert.Equal(t, []string{"Hello", " world"}, got)
|
||||
}
|
||||
|
||||
func TestChatComplete_EmptyResponse(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sseLines(w, []string{"[DONE]"})
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
c := NewClient(ts.URL)
|
||||
tokens, errFn := c.ChatComplete(context.Background(), ChatRequest{
|
||||
Messages: []ChatMessage{{Role: "user", Content: "Hi"}},
|
||||
Temperature: 0.7,
|
||||
Stream: true,
|
||||
})
|
||||
|
||||
var got []string
|
||||
for tok := range tokens {
|
||||
got = append(got, tok)
|
||||
}
|
||||
require.NoError(t, errFn())
|
||||
assert.Empty(t, got)
|
||||
}
|
||||
|
||||
func TestChatComplete_HTTPError(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "internal server error", http.StatusInternalServerError)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
c := NewClient(ts.URL)
|
||||
tokens, errFn := c.ChatComplete(context.Background(), ChatRequest{
|
||||
Messages: []ChatMessage{{Role: "user", Content: "Hi"}},
|
||||
Temperature: 0.7,
|
||||
Stream: true,
|
||||
})
|
||||
|
||||
var got []string
|
||||
for tok := range tokens {
|
||||
got = append(got, tok)
|
||||
}
|
||||
assert.Empty(t, got)
|
||||
err := errFn()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "500")
|
||||
}
|
||||
|
||||
// TestChatComplete_ContextCancelled verifies that cancelling the request
// context mid-stream terminates iteration after the chunks already received,
// against a server that deliberately never finishes its stream.
func TestChatComplete_ContextCancelled(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		f, ok := w.(http.Flusher)
		if !ok {
			panic("ResponseWriter does not implement Flusher")
		}
		w.Header().Set("Content-Type", "text/event-stream")
		w.Header().Set("Cache-Control", "no-cache")
		w.WriteHeader(http.StatusOK)

		// Send first chunk.
		fmt.Fprintf(w, "data: %s\n\n", `{"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}`)
		f.Flush()

		// Wait for context cancellation before sending more.
		// Blocking here (instead of closing) forces the client to observe
		// cancellation through its own context, not a server-side EOF.
		<-r.Context().Done()
	}))
	defer ts.Close()

	c := NewClient(ts.URL)
	tokens, errFn := c.ChatComplete(ctx, ChatRequest{
		Messages:    []ChatMessage{{Role: "user", Content: "Hi"}},
		Temperature: 0.7,
		Stream:      true,
	})

	var got []string
	for tok := range tokens {
		got = append(got, tok)
		cancel() // Cancel after receiving the first token.
	}
	// The error may or may not be nil depending on timing;
	// the important thing is we got exactly 1 token.
	_ = errFn()
	assert.Equal(t, []string{"Hello"}, got)
}
|
||||
Loading…
Add table
Reference in a new issue