diff --git a/internal/llamacpp/client_test.go b/internal/llamacpp/client_test.go index 73c68a9..20b49c7 100644 --- a/internal/llamacpp/client_test.go +++ b/internal/llamacpp/client_test.go @@ -139,3 +139,56 @@ func TestChatComplete_ContextCancelled(t *testing.T) { _ = errFn() assert.Equal(t, []string{"Hello"}, got) } + +func TestComplete_Streaming(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/v1/completions", r.URL.Path) + assert.Equal(t, "POST", r.Method) + sseLines(w, []string{ + `{"choices":[{"text":"Once","finish_reason":null}]}`, + `{"choices":[{"text":" upon","finish_reason":null}]}`, + `{"choices":[{"text":" a time","finish_reason":null}]}`, + `{"choices":[{"text":"","finish_reason":"stop"}]}`, + "[DONE]", + }) + })) + defer ts.Close() + + c := NewClient(ts.URL) + tokens, errFn := c.Complete(context.Background(), CompletionRequest{ + Prompt: "Once", + MaxTokens: 64, + Temperature: 0.0, + Stream: true, + }) + + var got []string + for tok := range tokens { + got = append(got, tok) + } + require.NoError(t, errFn()) + assert.Equal(t, []string{"Once", " upon", " a time"}, got) +} + +func TestComplete_HTTPError(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "bad request", http.StatusBadRequest) + })) + defer ts.Close() + + c := NewClient(ts.URL) + tokens, errFn := c.Complete(context.Background(), CompletionRequest{ + Prompt: "Hello", + Temperature: 0.7, + Stream: true, + }) + + var got []string + for tok := range tokens { + got = append(got, tok) + } + assert.Empty(t, got) + err := errFn() + require.Error(t, err) + assert.Contains(t, err.Error(), "400") +}