test: add integration tests for Classify, BatchGenerate, Info, Metrics
Verified on RX 7800 XT (gfx1100, ROCm 7.2): - Classify: greedy single-token via max_tokens=1 - BatchGenerate: sequential multi-prompt generation - Info: GGUF metadata (gemma3, 26 layers, Q5_K_M) - Metrics: 250 tok/s decode, 3244 MiB VRAM Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
b03f357f5d
commit
402bdc2205
1 changed files with 94 additions and 0 deletions
|
|
@ -201,6 +201,100 @@ func TestROCm_ConcurrentRequests(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestROCm_Classify(t *testing.T) {
|
||||||
|
skipIfNoROCm(t)
|
||||||
|
skipIfNoModel(t)
|
||||||
|
|
||||||
|
b := &rocmBackend{}
|
||||||
|
m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer m.Close()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
prompts := []string{
|
||||||
|
"The capital of France is",
|
||||||
|
"2 + 2 =",
|
||||||
|
}
|
||||||
|
|
||||||
|
results, err := m.Classify(ctx, prompts)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, results, 2)
|
||||||
|
|
||||||
|
for i, r := range results {
|
||||||
|
assert.NotEmpty(t, r.Token.Text, "classify result %d should have a token", i)
|
||||||
|
t.Logf("Classify %d: %q", i, r.Token.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestROCm_BatchGenerate(t *testing.T) {
|
||||||
|
skipIfNoROCm(t)
|
||||||
|
skipIfNoModel(t)
|
||||||
|
|
||||||
|
b := &rocmBackend{}
|
||||||
|
m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer m.Close()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
prompts := []string{
|
||||||
|
"The capital of France is",
|
||||||
|
"The capital of Germany is",
|
||||||
|
}
|
||||||
|
|
||||||
|
results, err := m.BatchGenerate(ctx, prompts, inference.WithMaxTokens(8))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, results, 2)
|
||||||
|
|
||||||
|
for i, r := range results {
|
||||||
|
require.NoError(t, r.Err, "batch result %d error", i)
|
||||||
|
assert.NotEmpty(t, r.Tokens, "batch result %d should have tokens", i)
|
||||||
|
|
||||||
|
var sb strings.Builder
|
||||||
|
for _, tok := range r.Tokens {
|
||||||
|
sb.WriteString(tok.Text)
|
||||||
|
}
|
||||||
|
t.Logf("Batch %d: %s", i, sb.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestROCm_InfoAndMetrics(t *testing.T) {
|
||||||
|
skipIfNoROCm(t)
|
||||||
|
skipIfNoModel(t)
|
||||||
|
|
||||||
|
b := &rocmBackend{}
|
||||||
|
m, err := b.LoadModel(testModel, inference.WithContextLen(2048))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer m.Close()
|
||||||
|
|
||||||
|
// Info should be populated from GGUF metadata.
|
||||||
|
info := m.Info()
|
||||||
|
assert.Equal(t, "gemma3", info.Architecture)
|
||||||
|
assert.Greater(t, info.NumLayers, 0, "expected non-zero layer count")
|
||||||
|
assert.Greater(t, info.QuantBits, 0, "expected non-zero quant bits")
|
||||||
|
t.Logf("Info: arch=%s layers=%d quant=%d-bit group=%d",
|
||||||
|
info.Architecture, info.NumLayers, info.QuantBits, info.QuantGroup)
|
||||||
|
|
||||||
|
// Generate some tokens to populate metrics.
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
for range m.Generate(ctx, "Hello", inference.WithMaxTokens(4)) {
|
||||||
|
}
|
||||||
|
require.NoError(t, m.Err())
|
||||||
|
|
||||||
|
met := m.Metrics()
|
||||||
|
assert.Greater(t, met.GeneratedTokens, 0, "expected generated tokens")
|
||||||
|
assert.Greater(t, met.TotalDuration, time.Duration(0), "expected non-zero duration")
|
||||||
|
assert.Greater(t, met.DecodeTokensPerSec, float64(0), "expected non-zero decode throughput")
|
||||||
|
t.Logf("Metrics: gen=%d tok, total=%s, decode=%.1f tok/s, vram=%d MiB",
|
||||||
|
met.GeneratedTokens, met.TotalDuration, met.DecodeTokensPerSec,
|
||||||
|
met.ActiveMemoryBytes/(1024*1024))
|
||||||
|
}
|
||||||
|
|
||||||
func TestROCm_DiscoverModels(t *testing.T) {
|
func TestROCm_DiscoverModels(t *testing.T) {
|
||||||
dir := filepath.Dir(testModel)
|
dir := filepath.Dir(testModel)
|
||||||
if _, err := os.Stat(dir); err != nil {
|
if _, err := os.Stat(dir); err != nil {
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue