// SPDX-License-Identifier: EUPL-1.2
package ml
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"forge.lthn.ai/core/go-inference"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// newTestServer creates an httptest.Server that responds with the given content.
// newTestServer creates an httptest.Server whose handler always answers with a
// single-choice chatResponse carrying the given assistant content.
//
// The handler runs on the server's goroutine, so it must not call t.Fatalf —
// FailNow is only valid from the goroutine running the test function. t.Errorf
// is safe to call from other goroutines and is used instead.
func newTestServer(t *testing.T, content string) *httptest.Server {
	t.Helper()
	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		resp := chatResponse{
			Choices: []chatChoice{{Message: Message{Role: "assistant", Content: content}}},
		}
		if err := json.NewEncoder(w).Encode(resp); err != nil {
			t.Errorf("encode response: %v", err)
		}
	}))
}
// newTestServerMulti creates an httptest.Server that responds based on the prompt.
// newTestServerMulti creates an httptest.Server that echoes the content of the
// last message in each request back with a "reply:" prefix, letting tests
// correlate responses with the prompts that produced them.
//
// The handler runs on the server's goroutine, so it must not call t.Fatalf —
// FailNow is only valid from the goroutine running the test function. t.Errorf
// is safe from other goroutines and is used instead; the handler also returns
// early on a decode failure rather than responding with garbage.
func newTestServerMulti(t *testing.T) *httptest.Server {
	t.Helper()
	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var req chatRequest
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			t.Errorf("decode request: %v", err)
			return
		}
		// Echo back the last message content with a prefix.
		lastContent := ""
		if n := len(req.Messages); n > 0 {
			lastContent = req.Messages[n-1].Content
		}
		resp := chatResponse{
			Choices: []chatChoice{{Message: Message{Role: "assistant", Content: "reply:" + lastContent}}},
		}
		if err := json.NewEncoder(w).Encode(resp); err != nil {
			t.Errorf("encode response: %v", err)
		}
	}))
}
// TestHTTPTextModel_Generate_Good checks that a successful backend response is
// surfaced as exactly one token and that the model reports no error.
func TestHTTPTextModel_Generate_Good(t *testing.T) {
	server := newTestServer(t, "Hello from HTTP")
	defer server.Close()

	model := NewHTTPTextModel(NewHTTPBackend(server.URL, "test-model"))

	var tokens []inference.Token
	for token := range model.Generate(context.Background(), "test prompt") {
		tokens = append(tokens, token)
	}

	require.Len(t, tokens, 1)
	assert.Equal(t, "Hello from HTTP", tokens[0].Text)
	assert.NoError(t, model.Err())
}
// TestHTTPTextModel_Generate_WithOpts_Good verifies that generation options
// (temperature, max tokens) are forwarded to the backend request.
func TestHTTPTextModel_Generate_WithOpts_Good(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var req chatRequest
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			// The handler runs on a server goroutine; t.Fatalf (FailNow) is
			// only valid from the test goroutine, so report and bail out.
			t.Errorf("decode: %v", err)
			return
		}
		// Verify that options are passed through.
		assert.InDelta(t, 0.8, req.Temperature, 0.01)
		assert.Equal(t, 100, req.MaxTokens)
		resp := chatResponse{
			Choices: []chatChoice{{Message: Message{Role: "assistant", Content: "configured"}}},
		}
		if err := json.NewEncoder(w).Encode(resp); err != nil {
			t.Errorf("encode: %v", err)
		}
	}))
	defer srv.Close()
	backend := NewHTTPBackend(srv.URL, "test-model")
	model := NewHTTPTextModel(backend)
	var result string
	for tok := range model.Generate(context.Background(), "prompt",
		inference.WithTemperature(0.8),
		inference.WithMaxTokens(100),
	) {
		result = tok.Text
	}
	assert.Equal(t, "configured", result)
	assert.NoError(t, model.Err())
}
// TestHTTPTextModel_Chat_Good checks that a multi-message chat yields the
// backend's single response token without error.
func TestHTTPTextModel_Chat_Good(t *testing.T) {
	server := newTestServer(t, "chat response")
	defer server.Close()

	model := NewHTTPTextModel(NewHTTPBackend(server.URL, "test-model"))

	conversation := []inference.Message{
		{Role: "system", Content: "You are helpful."},
		{Role: "user", Content: "Hello"},
	}

	var tokens []inference.Token
	for token := range model.Chat(context.Background(), conversation) {
		tokens = append(tokens, token)
	}

	require.Len(t, tokens, 1)
	assert.Equal(t, "chat response", tokens[0].Text)
	assert.NoError(t, model.Err())
}
func TestHTTPTextModel_Generate_Error_Bad(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("invalid request"))
}))
defer srv.Close()
backend := NewHTTPBackend(srv.URL, "test-model")
model := NewHTTPTextModel(backend)
var collected []inference.Token
for tok := range model.Generate(context.Background(), "bad prompt") {
collected = append(collected, tok)
}
assert.Empty(t, collected, "no tokens should be yielded on error")
assert.Error(t, model.Err())
assert.Contains(t, model.Err().Error(), "400")
}
func TestHTTPTextModel_Chat_Error_Bad(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("bad chat"))
}))
defer srv.Close()
backend := NewHTTPBackend(srv.URL, "test-model")
model := NewHTTPTextModel(backend)
messages := []inference.Message{{Role: "user", Content: "test"}}
var collected []inference.Token
for tok := range model.Chat(context.Background(), messages) {
collected = append(collected, tok)
}
assert.Empty(t, collected)
assert.Error(t, model.Err())
}
func TestHTTPTextModel_Classify_Bad(t *testing.T) {
backend := NewHTTPBackend("http://localhost", "test-model")
model := NewHTTPTextModel(backend)
results, err := model.Classify(context.Background(), []string{"test"})
assert.Nil(t, results)
assert.Error(t, err)
assert.Contains(t, err.Error(), "classify not supported")
}
// TestHTTPTextModel_BatchGenerate_Good checks that each prompt in a batch gets
// its own echoed response, in order, with no per-result error.
func TestHTTPTextModel_BatchGenerate_Good(t *testing.T) {
	server := newTestServerMulti(t)
	defer server.Close()

	model := NewHTTPTextModel(NewHTTPBackend(server.URL, "test-model"))

	results, err := model.BatchGenerate(context.Background(), []string{"alpha", "beta", "gamma"})
	require.NoError(t, err)
	require.Len(t, results, 3)

	for i, want := range []string{"reply:alpha", "reply:beta", "reply:gamma"} {
		assert.Equal(t, want, results[i].Tokens[0].Text)
		assert.NoError(t, results[i].Err)
	}
}
// TestHTTPTextModel_BatchGenerate_PartialError_Bad verifies that a failure on
// one prompt is reported on that result alone, without failing the batch call
// or its sibling results.
func TestHTTPTextModel_BatchGenerate_PartialError_Bad(t *testing.T) {
	// NOTE(review): the plain counter assumes BatchGenerate issues its
	// requests one at a time — confirm; concurrent requests would need atomics.
	requests := 0
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		requests++
		if requests == 2 {
			w.WriteHeader(http.StatusBadRequest)
			_, _ = w.Write([]byte("error on second"))
			return
		}
		body := chatResponse{
			Choices: []chatChoice{{Message: Message{Role: "assistant", Content: "ok"}}},
		}
		_ = json.NewEncoder(w).Encode(body) // best-effort in a test stub
	}))
	defer server.Close()

	model := NewHTTPTextModel(NewHTTPBackend(server.URL, "test-model"))

	results, err := model.BatchGenerate(context.Background(), []string{"a", "b", "c"})
	require.NoError(t, err) // BatchGenerate itself doesn't fail.
	require.Len(t, results, 3)

	assert.NoError(t, results[0].Err)
	assert.Len(t, results[0].Tokens, 1)
	assert.Error(t, results[1].Err, "second prompt should have an error")
	assert.Empty(t, results[1].Tokens)
	assert.NoError(t, results[2].Err)
	assert.Len(t, results[2].Tokens, 1)
}
func TestHTTPTextModel_ModelType_Good(t *testing.T) {
backend := NewHTTPBackend("http://localhost", "gpt-4o")
model := NewHTTPTextModel(backend)
assert.Equal(t, "gpt-4o", model.ModelType())
}
func TestHTTPTextModel_ModelType_Empty_Good(t *testing.T) {
backend := NewHTTPBackend("http://localhost", "")
model := NewHTTPTextModel(backend)
assert.Equal(t, "http", model.ModelType())
}
func TestHTTPTextModel_Info_Good(t *testing.T) {
backend := NewHTTPBackend("http://localhost", "test")
model := NewHTTPTextModel(backend)
info := model.Info()
assert.Equal(t, "http", info.Architecture)
}
func TestHTTPTextModel_Metrics_Good(t *testing.T) {
backend := NewHTTPBackend("http://localhost", "test")
model := NewHTTPTextModel(backend)
metrics := model.Metrics()
assert.Equal(t, inference.GenerateMetrics{}, metrics)
}
// TestHTTPTextModel_Err_ClearedOnSuccess_Good verifies that a subsequent
// successful Generate clears the error recorded by an earlier failure.
func TestHTTPTextModel_Err_ClearedOnSuccess_Good(t *testing.T) {
	// The counter distinguishes the first (failing) request from later ones;
	// the two Generate calls below run back-to-back on one goroutine.
	requests := 0
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		requests++
		if requests == 1 {
			// First request fails.
			w.WriteHeader(http.StatusBadRequest)
			_, _ = w.Write([]byte("fail"))
			return
		}
		body := chatResponse{
			Choices: []chatChoice{{Message: Message{Role: "assistant", Content: "ok"}}},
		}
		_ = json.NewEncoder(w).Encode(body) // best-effort in a test stub
	}))
	defer server.Close()

	model := NewHTTPTextModel(NewHTTPBackend(server.URL, "test-model"))

	// First call: error.
	for range model.Generate(context.Background(), "fail") {
	}
	assert.Error(t, model.Err())

	// Second call: success — error should be cleared.
	for range model.Generate(context.Background(), "ok") {
	}
	assert.NoError(t, model.Err())
}
func TestHTTPTextModel_Close_Good(t *testing.T) {
backend := NewHTTPBackend("http://localhost", "test")
model := NewHTTPTextModel(backend)
assert.NoError(t, model.Close())
}
// TestLlamaTextModel_ModelType_Good checks the fixed "llama" model type.
func TestLlamaTextModel_ModelType_Good(t *testing.T) {
	// LlamaBackend requires a process.Service but we only test ModelType here.
	backend := &LlamaBackend{
		http: NewHTTPBackend("http://127.0.0.1:18090", ""),
	}
	assert.Equal(t, "llama", NewLlamaTextModel(backend).ModelType())
}
// TestLlamaTextModel_Close_Good checks that Close succeeds when no llama
// process was ever started.
func TestLlamaTextModel_Close_Good(t *testing.T) {
	// LlamaBackend with no procID — Stop() is a no-op.
	backend := &LlamaBackend{
		http: NewHTTPBackend("http://127.0.0.1:18090", ""),
	}
	assert.NoError(t, NewLlamaTextModel(backend).Close())
}