diff --git a/internal/llamacpp/client.go b/internal/llamacpp/client.go index 82263ed..bc2577d 100644 --- a/internal/llamacpp/client.go +++ b/internal/llamacpp/client.go @@ -10,6 +10,7 @@ import ( "iter" "net/http" "strings" + "sync" ) // ChatMessage is a single message in a conversation. @@ -72,6 +73,7 @@ func (c *Client) ChatComplete(ctx context.Context, req ChatRequest) (iter.Seq[st return noChunks, func() error { return fmt.Errorf("llamacpp: create chat request: %w", err) } } httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Accept", "text/event-stream") resp, err := c.httpClient.Do(httpReq) if err != nil { @@ -86,11 +88,15 @@ func (c *Client) ChatComplete(ctx context.Context, req ChatRequest) (iter.Seq[st } } - var streamErr error + var ( + streamErr error + closeOnce sync.Once + closeBody = func() { closeOnce.Do(func() { resp.Body.Close() }) } + ) sseData := parseSSE(resp.Body, &streamErr) tokens := func(yield func(string) bool) { - defer resp.Body.Close() + defer closeBody() for raw := range sseData { var chunk chatChunkResponse if err := json.Unmarshal([]byte(raw), &chunk); err != nil { @@ -110,7 +116,10 @@ func (c *Client) ChatComplete(ctx context.Context, req ChatRequest) (iter.Seq[st } } - return tokens, func() error { return streamErr } + return tokens, func() error { + closeBody() + return streamErr + } } // Complete sends a streaming completion request to /v1/completions. @@ -129,6 +138,7 @@ func (c *Client) Complete(ctx context.Context, req CompletionRequest) (iter.Seq[ return noChunks, func() error { return fmt.Errorf("llamacpp: create completion request: %w", err) } } httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Accept", "text/event-stream") resp, err := c.httpClient.Do(httpReq) if err != nil { @@ -143,11 +153,15 @@ func (c *Client) Complete(ctx context.Context, req CompletionRequest) (iter.Seq[ } } - var streamErr error + var ( + streamErr error + closeOnce sync.Once + closeBody = func() { closeOnce.Do(func() { resp.Body.Close() }) } + ) sseData := parseSSE(resp.Body, &streamErr) tokens := func(yield func(string) bool) { - defer resp.Body.Close() + defer closeBody() for raw := range sseData { var chunk completionChunkResponse if err := json.Unmarshal([]byte(raw), &chunk); err != nil { @@ -167,7 +181,10 @@ func (c *Client) Complete(ctx context.Context, req CompletionRequest) (iter.Seq[ } } - return tokens, func() error { return streamErr } + return tokens, func() error { + closeBody() + return streamErr + } } // parseSSE reads SSE-formatted lines from r and yields the payload of each