go-ml/backend_mlx_test.go
Snider 7601614e07 test: update all test assertions for Result type
All Backend.Generate/Chat calls now return Result. Test assertions
updated to use .Text and check .Metrics where appropriate.

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-22 17:45:32 +00:00

157 lines
4.9 KiB
Go

// SPDX-License-Identifier: EUPL-1.2
//go:build darwin && arm64
package ml
import (
"context"
"testing"
"forge.lthn.ai/core/go-inference"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// backend_mlx.go tests — uses mockTextModel from adapter_test.go
// since we cannot load real MLX models in CI
// ---------------------------------------------------------------------------
// TestMLXBackend_InferenceAdapter_Generate_Good verifies that an
// InferenceAdapter (the type returned by NewMLXBackend) correctly
// generates text through a mock TextModel.
func TestMLXBackend_InferenceAdapter_Generate_Good(t *testing.T) {
mock := &mockTextModel{
tokens: []inference.Token{
{ID: 1, Text: "MLX "},
{ID: 2, Text: "output"},
},
modelType: "qwen3",
}
adapter := NewInferenceAdapter(mock, "mlx")
// The adapter should satisfy Backend.
var backend Backend = adapter
assert.Equal(t, "mlx", backend.Name())
assert.True(t, backend.Available())
result, err := backend.Generate(context.Background(), "prompt", GenOpts{Temperature: 0.5})
require.NoError(t, err)
assert.Equal(t, "MLX output", result.Text)
assert.NotNil(t, result.Metrics)
}
// TestMLXBackend_InferenceAdapter_Chat_Good sends a single user message
// through Chat on an InferenceAdapter (the wrapper NewMLXBackend uses). It
// checks that the mock's tokens are joined into the reply text and that
// Metrics is set.
func TestMLXBackend_InferenceAdapter_Chat_Good(t *testing.T) {
	chatModel := &mockTextModel{tokens: []inference.Token{
		{ID: 1, Text: "chat "},
		{ID: 2, Text: "reply"},
	}}
	backend := NewInferenceAdapter(chatModel, "mlx")

	history := []Message{{Role: "user", Content: "hello"}}
	reply, err := backend.Chat(context.Background(), history, GenOpts{})
	require.NoError(t, err)

	check := assert.New(t)
	check.Equal("chat reply", reply.Text)
	check.NotNil(reply.Metrics)
}
// TestMLXBackend_InferenceAdapter_Stream_Good calls GenerateStream through
// the StreamingBackend interface. It checks that the callback receives each
// mock token's text, in order, and that the stream finishes without error.
func TestMLXBackend_InferenceAdapter_Stream_Good(t *testing.T) {
	want := []string{"tok1", "tok2", "tok3"}

	model := &mockTextModel{tokens: []inference.Token{
		{ID: 1, Text: "tok1"},
		{ID: 2, Text: "tok2"},
		{ID: 3, Text: "tok3"},
	}}

	// Assigning to the interface doubles as a StreamingBackend compliance check.
	var sb StreamingBackend = NewInferenceAdapter(model, "mlx")

	var got []string
	onToken := func(text string) error {
		got = append(got, text)
		return nil
	}
	require.NoError(t, sb.GenerateStream(context.Background(), "prompt", GenOpts{}, onToken))
	assert.Equal(t, want, got)
}
// TestMLXBackend_InferenceAdapter_ModelError_Bad checks the error path of
// InferenceAdapter.Generate (the MLX path):
//   - the error reported by the underlying TextModel reaches the caller;
//   - text generated before the failure is still returned;
//   - Metrics is left nil.
func TestMLXBackend_InferenceAdapter_ModelError_Bad(t *testing.T) {
	mock := &mockTextModel{
		tokens: []inference.Token{
			{ID: 1, Text: "partial"},
		},
		err:       assert.AnError,
		modelType: "qwen3",
	}
	adapter := NewInferenceAdapter(mock, "mlx")
	result, err := adapter.Generate(context.Background(), "prompt", GenOpts{})
	// ErrorIs rather than Error: it proves the mock's own error propagated
	// (directly or wrapped with %w), not merely that some error occurred.
	assert.ErrorIs(t, err, assert.AnError)
	assert.Equal(t, "partial", result.Text, "partial output should still be returned")
	assert.Nil(t, result.Metrics, "metrics should be nil on error")
}
// TestMLXBackend_InferenceAdapter_Close_Good checks that closing the adapter
// succeeds and marks the wrapped mock TextModel as closed.
func TestMLXBackend_InferenceAdapter_Close_Good(t *testing.T) {
	underlying := &mockTextModel{}
	wrapper := NewInferenceAdapter(underlying, "mlx")

	require.NoError(t, wrapper.Close())
	assert.True(t, underlying.closed)
}
// TestMLXBackend_InferenceAdapter_ModelAccess_Good verifies that the
// underlying TextModel is accessible for direct operations.
func TestMLXBackend_InferenceAdapter_ModelAccess_Good(t *testing.T) {
mock := &mockTextModel{modelType: "llama"}
adapter := NewInferenceAdapter(mock, "mlx")
model := adapter.Model()
assert.Equal(t, "llama", model.ModelType())
assert.Equal(t, inference.ModelInfo{}, model.Info())
}
// TestMLXBackend_InterfaceCompliance_Good is a compile-time guard: the
// package only builds if *InferenceAdapter (the return type of
// NewMLXBackend) implements both Backend and StreamingBackend. Nothing is
// asserted at run time.
func TestMLXBackend_InterfaceCompliance_Good(t *testing.T) {
	var (
		_ Backend          = (*InferenceAdapter)(nil)
		_ StreamingBackend = (*InferenceAdapter)(nil)
	)
}
// TestMLXBackend_ConvertOpts_Temperature_Good checks that a GenOpts with
// Temperature and MaxTokens set converts to exactly two inference options.
//
// NOTE(review): this only counts the options convertOpts returns. It does
// not inspect the value each option carries, so a temperature mapped to the
// wrong value (or the wrong field) would still pass.
func TestMLXBackend_ConvertOpts_Temperature_Good(t *testing.T) {
	opts := convertOpts(GenOpts{Temperature: 0.8, MaxTokens: 2048})
	assert.Len(t, opts, 2, "Temperature + MaxTokens should convert to two options")
}
// TestMLXBackend_ConvertOpts_AllFields_Good verifies all GenOpts fields
// produce the expected number of inference options.
func TestMLXBackend_ConvertOpts_AllFields_Good(t *testing.T) {
opts := convertOpts(GenOpts{
Temperature: 0.7,
MaxTokens: 512,
TopK: 40,
TopP: 0.9,
RepeatPenalty: 1.1,
})
assert.Len(t, opts, 5)
}