307 lines
8.8 KiB
Go
307 lines
8.8 KiB
Go
|
|
// SPDX-License-Identifier: EUPL-1.2
|
||
|
|
|
||
|
|
package ml
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"encoding/json"
|
||
|
|
"net/http"
|
||
|
|
"net/http/httptest"
|
||
|
|
"testing"
|
||
|
|
|
||
|
|
"forge.lthn.ai/core/go-inference"
|
||
|
|
|
||
|
|
"github.com/stretchr/testify/assert"
|
||
|
|
"github.com/stretchr/testify/require"
|
||
|
|
)
|
||
|
|
|
||
|
|
// newTestServer creates an httptest.Server that responds with the given content.
|
||
|
|
func newTestServer(t *testing.T, content string) *httptest.Server {
|
||
|
|
t.Helper()
|
||
|
|
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
resp := chatResponse{
|
||
|
|
Choices: []chatChoice{{Message: Message{Role: "assistant", Content: content}}},
|
||
|
|
}
|
||
|
|
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||
|
|
t.Fatalf("encode response: %v", err)
|
||
|
|
}
|
||
|
|
}))
|
||
|
|
}
|
||
|
|
|
||
|
|
// newTestServerMulti creates an httptest.Server that responds based on the prompt.
|
||
|
|
func newTestServerMulti(t *testing.T) *httptest.Server {
|
||
|
|
t.Helper()
|
||
|
|
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
var req chatRequest
|
||
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||
|
|
t.Fatalf("decode request: %v", err)
|
||
|
|
}
|
||
|
|
// Echo back the last message content with a prefix.
|
||
|
|
lastContent := ""
|
||
|
|
if len(req.Messages) > 0 {
|
||
|
|
lastContent = req.Messages[len(req.Messages)-1].Content
|
||
|
|
}
|
||
|
|
resp := chatResponse{
|
||
|
|
Choices: []chatChoice{{Message: Message{Role: "assistant", Content: "reply:" + lastContent}}},
|
||
|
|
}
|
||
|
|
json.NewEncoder(w).Encode(resp)
|
||
|
|
}))
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHTTPTextModel_Generate_Good(t *testing.T) {
|
||
|
|
srv := newTestServer(t, "Hello from HTTP")
|
||
|
|
defer srv.Close()
|
||
|
|
|
||
|
|
backend := NewHTTPBackend(srv.URL, "test-model")
|
||
|
|
model := NewHTTPTextModel(backend)
|
||
|
|
|
||
|
|
var collected []inference.Token
|
||
|
|
for tok := range model.Generate(context.Background(), "test prompt") {
|
||
|
|
collected = append(collected, tok)
|
||
|
|
}
|
||
|
|
|
||
|
|
require.Len(t, collected, 1)
|
||
|
|
assert.Equal(t, "Hello from HTTP", collected[0].Text)
|
||
|
|
assert.NoError(t, model.Err())
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHTTPTextModel_Generate_WithOpts_Good(t *testing.T) {
|
||
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
var req chatRequest
|
||
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||
|
|
t.Fatalf("decode: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Verify that options are passed through.
|
||
|
|
assert.InDelta(t, 0.8, req.Temperature, 0.01)
|
||
|
|
assert.Equal(t, 100, req.MaxTokens)
|
||
|
|
|
||
|
|
resp := chatResponse{
|
||
|
|
Choices: []chatChoice{{Message: Message{Role: "assistant", Content: "configured"}}},
|
||
|
|
}
|
||
|
|
json.NewEncoder(w).Encode(resp)
|
||
|
|
}))
|
||
|
|
defer srv.Close()
|
||
|
|
|
||
|
|
backend := NewHTTPBackend(srv.URL, "test-model")
|
||
|
|
model := NewHTTPTextModel(backend)
|
||
|
|
|
||
|
|
var result string
|
||
|
|
for tok := range model.Generate(context.Background(), "prompt",
|
||
|
|
inference.WithTemperature(0.8),
|
||
|
|
inference.WithMaxTokens(100),
|
||
|
|
) {
|
||
|
|
result = tok.Text
|
||
|
|
}
|
||
|
|
assert.Equal(t, "configured", result)
|
||
|
|
assert.NoError(t, model.Err())
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHTTPTextModel_Chat_Good(t *testing.T) {
|
||
|
|
srv := newTestServer(t, "chat response")
|
||
|
|
defer srv.Close()
|
||
|
|
|
||
|
|
backend := NewHTTPBackend(srv.URL, "test-model")
|
||
|
|
model := NewHTTPTextModel(backend)
|
||
|
|
|
||
|
|
messages := []inference.Message{
|
||
|
|
{Role: "system", Content: "You are helpful."},
|
||
|
|
{Role: "user", Content: "Hello"},
|
||
|
|
}
|
||
|
|
|
||
|
|
var collected []inference.Token
|
||
|
|
for tok := range model.Chat(context.Background(), messages) {
|
||
|
|
collected = append(collected, tok)
|
||
|
|
}
|
||
|
|
|
||
|
|
require.Len(t, collected, 1)
|
||
|
|
assert.Equal(t, "chat response", collected[0].Text)
|
||
|
|
assert.NoError(t, model.Err())
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHTTPTextModel_Generate_Error_Bad(t *testing.T) {
|
||
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
w.WriteHeader(http.StatusBadRequest)
|
||
|
|
w.Write([]byte("invalid request"))
|
||
|
|
}))
|
||
|
|
defer srv.Close()
|
||
|
|
|
||
|
|
backend := NewHTTPBackend(srv.URL, "test-model")
|
||
|
|
model := NewHTTPTextModel(backend)
|
||
|
|
|
||
|
|
var collected []inference.Token
|
||
|
|
for tok := range model.Generate(context.Background(), "bad prompt") {
|
||
|
|
collected = append(collected, tok)
|
||
|
|
}
|
||
|
|
|
||
|
|
assert.Empty(t, collected, "no tokens should be yielded on error")
|
||
|
|
assert.Error(t, model.Err())
|
||
|
|
assert.Contains(t, model.Err().Error(), "400")
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHTTPTextModel_Chat_Error_Bad(t *testing.T) {
|
||
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
w.WriteHeader(http.StatusBadRequest)
|
||
|
|
w.Write([]byte("bad chat"))
|
||
|
|
}))
|
||
|
|
defer srv.Close()
|
||
|
|
|
||
|
|
backend := NewHTTPBackend(srv.URL, "test-model")
|
||
|
|
model := NewHTTPTextModel(backend)
|
||
|
|
|
||
|
|
messages := []inference.Message{{Role: "user", Content: "test"}}
|
||
|
|
var collected []inference.Token
|
||
|
|
for tok := range model.Chat(context.Background(), messages) {
|
||
|
|
collected = append(collected, tok)
|
||
|
|
}
|
||
|
|
|
||
|
|
assert.Empty(t, collected)
|
||
|
|
assert.Error(t, model.Err())
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHTTPTextModel_Classify_Bad(t *testing.T) {
|
||
|
|
backend := NewHTTPBackend("http://localhost", "test-model")
|
||
|
|
model := NewHTTPTextModel(backend)
|
||
|
|
|
||
|
|
results, err := model.Classify(context.Background(), []string{"test"})
|
||
|
|
assert.Nil(t, results)
|
||
|
|
assert.Error(t, err)
|
||
|
|
assert.Contains(t, err.Error(), "classify not supported")
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHTTPTextModel_BatchGenerate_Good(t *testing.T) {
|
||
|
|
srv := newTestServerMulti(t)
|
||
|
|
defer srv.Close()
|
||
|
|
|
||
|
|
backend := NewHTTPBackend(srv.URL, "test-model")
|
||
|
|
model := NewHTTPTextModel(backend)
|
||
|
|
|
||
|
|
prompts := []string{"alpha", "beta", "gamma"}
|
||
|
|
results, err := model.BatchGenerate(context.Background(), prompts)
|
||
|
|
require.NoError(t, err)
|
||
|
|
require.Len(t, results, 3)
|
||
|
|
|
||
|
|
assert.Equal(t, "reply:alpha", results[0].Tokens[0].Text)
|
||
|
|
assert.NoError(t, results[0].Err)
|
||
|
|
|
||
|
|
assert.Equal(t, "reply:beta", results[1].Tokens[0].Text)
|
||
|
|
assert.NoError(t, results[1].Err)
|
||
|
|
|
||
|
|
assert.Equal(t, "reply:gamma", results[2].Tokens[0].Text)
|
||
|
|
assert.NoError(t, results[2].Err)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHTTPTextModel_BatchGenerate_PartialError_Bad(t *testing.T) {
|
||
|
|
callCount := 0
|
||
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
callCount++
|
||
|
|
if callCount == 2 {
|
||
|
|
w.WriteHeader(http.StatusBadRequest)
|
||
|
|
w.Write([]byte("error on second"))
|
||
|
|
return
|
||
|
|
}
|
||
|
|
resp := chatResponse{
|
||
|
|
Choices: []chatChoice{{Message: Message{Role: "assistant", Content: "ok"}}},
|
||
|
|
}
|
||
|
|
json.NewEncoder(w).Encode(resp)
|
||
|
|
}))
|
||
|
|
defer srv.Close()
|
||
|
|
|
||
|
|
backend := NewHTTPBackend(srv.URL, "test-model")
|
||
|
|
model := NewHTTPTextModel(backend)
|
||
|
|
|
||
|
|
results, err := model.BatchGenerate(context.Background(), []string{"a", "b", "c"})
|
||
|
|
require.NoError(t, err) // BatchGenerate itself doesn't fail.
|
||
|
|
require.Len(t, results, 3)
|
||
|
|
|
||
|
|
assert.Len(t, results[0].Tokens, 1)
|
||
|
|
assert.NoError(t, results[0].Err)
|
||
|
|
|
||
|
|
assert.Empty(t, results[1].Tokens)
|
||
|
|
assert.Error(t, results[1].Err, "second prompt should have an error")
|
||
|
|
|
||
|
|
assert.Len(t, results[2].Tokens, 1)
|
||
|
|
assert.NoError(t, results[2].Err)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHTTPTextModel_ModelType_Good(t *testing.T) {
|
||
|
|
backend := NewHTTPBackend("http://localhost", "gpt-4o")
|
||
|
|
model := NewHTTPTextModel(backend)
|
||
|
|
assert.Equal(t, "gpt-4o", model.ModelType())
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHTTPTextModel_ModelType_Empty_Good(t *testing.T) {
|
||
|
|
backend := NewHTTPBackend("http://localhost", "")
|
||
|
|
model := NewHTTPTextModel(backend)
|
||
|
|
assert.Equal(t, "http", model.ModelType())
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHTTPTextModel_Info_Good(t *testing.T) {
|
||
|
|
backend := NewHTTPBackend("http://localhost", "test")
|
||
|
|
model := NewHTTPTextModel(backend)
|
||
|
|
info := model.Info()
|
||
|
|
assert.Equal(t, "http", info.Architecture)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHTTPTextModel_Metrics_Good(t *testing.T) {
|
||
|
|
backend := NewHTTPBackend("http://localhost", "test")
|
||
|
|
model := NewHTTPTextModel(backend)
|
||
|
|
metrics := model.Metrics()
|
||
|
|
assert.Equal(t, inference.GenerateMetrics{}, metrics)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHTTPTextModel_Err_ClearedOnSuccess_Good(t *testing.T) {
|
||
|
|
// First request fails.
|
||
|
|
callCount := 0
|
||
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
callCount++
|
||
|
|
if callCount == 1 {
|
||
|
|
w.WriteHeader(http.StatusBadRequest)
|
||
|
|
w.Write([]byte("fail"))
|
||
|
|
return
|
||
|
|
}
|
||
|
|
resp := chatResponse{
|
||
|
|
Choices: []chatChoice{{Message: Message{Role: "assistant", Content: "ok"}}},
|
||
|
|
}
|
||
|
|
json.NewEncoder(w).Encode(resp)
|
||
|
|
}))
|
||
|
|
defer srv.Close()
|
||
|
|
|
||
|
|
backend := NewHTTPBackend(srv.URL, "test-model")
|
||
|
|
model := NewHTTPTextModel(backend)
|
||
|
|
|
||
|
|
// First call: error.
|
||
|
|
for range model.Generate(context.Background(), "fail") {
|
||
|
|
}
|
||
|
|
assert.Error(t, model.Err())
|
||
|
|
|
||
|
|
// Second call: success — error should be cleared.
|
||
|
|
for range model.Generate(context.Background(), "ok") {
|
||
|
|
}
|
||
|
|
assert.NoError(t, model.Err())
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestHTTPTextModel_Close_Good(t *testing.T) {
|
||
|
|
backend := NewHTTPBackend("http://localhost", "test")
|
||
|
|
model := NewHTTPTextModel(backend)
|
||
|
|
assert.NoError(t, model.Close())
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestLlamaTextModel_ModelType_Good(t *testing.T) {
|
||
|
|
// LlamaBackend requires a process.Service but we only test ModelType here.
|
||
|
|
llama := &LlamaBackend{
|
||
|
|
http: NewHTTPBackend("http://127.0.0.1:18090", ""),
|
||
|
|
}
|
||
|
|
model := NewLlamaTextModel(llama)
|
||
|
|
assert.Equal(t, "llama", model.ModelType())
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestLlamaTextModel_Close_Good(t *testing.T) {
|
||
|
|
// LlamaBackend with no procID — Stop() is a no-op.
|
||
|
|
llama := &LlamaBackend{
|
||
|
|
http: NewHTTPBackend("http://127.0.0.1:18090", ""),
|
||
|
|
}
|
||
|
|
model := NewLlamaTextModel(llama)
|
||
|
|
assert.NoError(t, model.Close())
|
||
|
|
}
|