1
0
Fork 0
forked from lthn/LEM
LEM/pkg/lem/expand_test.go
Claude 70dd18c065
refactor: move Go library to pkg/lem, thin main.go
All scoring/influx/export/expand logic moves to pkg/lem as an
importable package. main.go is now a thin CLI dispatcher.

This lets new commands import the shared library directly —
ready for converting Python scripts to Go subcommands.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 16:30:09 +00:00

1028 lines
30 KiB
Go

package lem
import (
"bufio"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"sync/atomic"
"testing"
)
// mockExpandServer starts an httptest server that answers both the
// OpenAI-compatible chat endpoint and the InfluxDB endpoints used by the
// expand tests. The caller is responsible for closing the returned server.
//
//   - /v1/chat/completions replies with a single assistant choice whose
//     content is apiResponse.
//   - /api/v3/query_sql decodes the JSON request body and passes its "q" field
//     to influxQueryHandler, replying with the rows and status it returns.
//     A nil handler yields an empty result set with 200 OK; a body that is
//     not valid JSON yields 400.
//   - /api/v3/write_lp passes the raw request body to influxWriteHandler
//     (when non-nil) and replies 200 OK.
//
// Any other path replies 404.
func mockExpandServer(t *testing.T, apiResponse string, influxQueryHandler func(q string) ([]map[string]interface{}, int), influxWriteHandler func(body string)) *httptest.Server {
	t.Helper()
	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/v1/chat/completions":
			resp := ChatResponse{
				Choices: []Choice{
					{Message: Message{Role: "assistant", Content: apiResponse}},
				},
			}
			w.Header().Set("Content-Type", "application/json")
			if err := json.NewEncoder(w).Encode(resp); err != nil {
				t.Errorf("mock chat API: encode response: %v", err)
			}
		case "/api/v3/query_sql":
			var req struct {
				DB string `json:"db"`
				Q  string `json:"q"`
			}
			if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
				w.WriteHeader(http.StatusBadRequest)
				return
			}
			// Treat a nil handler like the write handler: optional, with a
			// harmless default (no rows) instead of a nil-func panic.
			rows, status := []map[string]interface{}{}, http.StatusOK
			if influxQueryHandler != nil {
				rows, status = influxQueryHandler(req.Q)
			}
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(status)
			// Only logged: a handler-chosen status may legitimately forbid a
			// response body, which makes Encode return an error.
			if err := json.NewEncoder(w).Encode(rows); err != nil {
				t.Logf("mock InfluxDB: encode query rows: %v", err)
			}
		case "/api/v3/write_lp":
			body, err := io.ReadAll(r.Body)
			if err != nil {
				t.Errorf("mock InfluxDB: read write_lp body: %v", err)
				w.WriteHeader(http.StatusBadRequest)
				return
			}
			if influxWriteHandler != nil {
				influxWriteHandler(string(body))
			}
			w.WriteHeader(http.StatusOK)
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	}))
}
// TestGetCompletedIDs exercises getCompletedIDs against mock InfluxDB servers:
// a populated result set, an empty result set, and a server that always
// answers 500.
func TestGetCompletedIDs(t *testing.T) {
	t.Run("returns completed IDs from InfluxDB", func(t *testing.T) {
		want := []string{"prompt_001", "prompt_002", "prompt_003"}

		srv := mockInfluxServer(t, func(query string) ([]map[string]any, int) {
			// The lookup is expected to mention the expansion_gen measurement.
			if !strings.Contains(query, "expansion_gen") {
				t.Errorf("expected query against expansion_gen, got: %s", query)
			}
			rows := make([]map[string]any, 0, len(want))
			for _, id := range want {
				rows = append(rows, map[string]any{"seed_id": id})
			}
			return rows, http.StatusOK
		})
		defer srv.Close()

		t.Setenv("INFLUX_TOKEN", "test-token")
		done, err := getCompletedIDs(NewInfluxClient(srv.URL, "training"))
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if n := len(done); n != 3 {
			t.Fatalf("expected 3 completed IDs, got %d", n)
		}
		for _, id := range want {
			if !done[id] {
				t.Errorf("expected ID %q to be in completed set", id)
			}
		}
	})

	t.Run("returns empty set when no completed IDs", func(t *testing.T) {
		srv := mockInfluxServer(t, func(string) ([]map[string]any, int) {
			return []map[string]any{}, http.StatusOK
		})
		defer srv.Close()

		t.Setenv("INFLUX_TOKEN", "test-token")
		done, err := getCompletedIDs(NewInfluxClient(srv.URL, "training"))
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if n := len(done); n != 0 {
			t.Errorf("expected 0 completed IDs, got %d", n)
		}
	})

	t.Run("returns error on InfluxDB failure", func(t *testing.T) {
		// Every request, whatever the path, is answered with a 500.
		alwaysFail := func(w http.ResponseWriter, _ *http.Request) {
			w.WriteHeader(http.StatusInternalServerError)
			_, _ = w.Write([]byte("database error"))
		}
		srv := httptest.NewServer(http.HandlerFunc(alwaysFail))
		defer srv.Close()

		t.Setenv("INFLUX_TOKEN", "test-token")
		if _, err := getCompletedIDs(NewInfluxClient(srv.URL, "training")); err == nil {
			t.Fatal("expected error on InfluxDB failure, got nil")
		}
	})
}
// TestExpandPromptsBasic runs expandPrompts over three prompts with no
// completed IDs reported and checks that the chat API is called once per
// prompt, that InfluxDB receives at least one write, and that the worker's
// JSONL output file holds one well-formed record per prompt.
func TestExpandPromptsBasic(t *testing.T) {
	var apiCalls atomic.Int32
	var writtenLines []string

	// The combined mock is used only for its InfluxDB endpoints here; chat
	// completions go to the dedicated server below so calls can be counted.
	server := mockExpandServer(t,
		"This is a generated response about ethics and sovereignty.",
		func(q string) ([]map[string]interface{}, int) {
			// No completed IDs
			return []map[string]interface{}{}, http.StatusOK
		},
		func(body string) {
			writtenLines = append(writtenLines, body)
		},
	)
	defer server.Close()

	// Chat API that counts calls and returns a fixed response.
	apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/v1/chat/completions" {
			apiCalls.Add(1)
			resp := ChatResponse{
				Choices: []Choice{
					{Message: Message{Role: "assistant", Content: "Generated response text."}},
				},
			}
			w.Header().Set("Content-Type", "application/json")
			json.NewEncoder(w).Encode(resp)
		}
	}))
	defer apiServer.Close()

	t.Setenv("INFLUX_TOKEN", "test-token")
	influx := NewInfluxClient(server.URL, "training")
	client := NewClient(apiServer.URL, "test-model")
	client.MaxTokens = 2048
	outputDir := t.TempDir()

	prompts := []Response{
		{ID: "p1", Domain: "ethics", Prompt: "What is ethics?"},
		{ID: "p2", Domain: "sovereignty", Prompt: "Define sovereignty."},
		{ID: "p3", Domain: "consent", Prompt: "Explain consent."},
	}
	err := expandPrompts(client, influx, nil, prompts, "test-model", "test-worker", outputDir, false)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Should have called the API 3 times.
	if got := apiCalls.Load(); got != 3 {
		t.Errorf("expected 3 API calls, got %d", got)
	}

	// The write handler's capture was previously never inspected; make it
	// count by requiring that some line protocol reached the mock InfluxDB.
	if len(writtenLines) == 0 {
		t.Error("expected at least one InfluxDB write, got none")
	}

	// Output file should exist and contain 3 lines.
	outputFile := filepath.Join(outputDir, "expand-test-worker.jsonl")
	data, err := os.ReadFile(outputFile)
	if err != nil {
		t.Fatalf("read output file: %v", err)
	}
	lines := strings.Split(strings.TrimSpace(string(data)), "\n")
	if len(lines) != 3 {
		t.Fatalf("expected 3 output lines, got %d", len(lines))
	}

	// Parse each line and verify structure.
	for i, line := range lines {
		var r Response
		if err := json.Unmarshal([]byte(line), &r); err != nil {
			t.Fatalf("line %d: failed to parse JSON: %v", i, err)
		}
		if r.Response != "Generated response text." {
			t.Errorf("line %d: expected response 'Generated response text.', got %q", i, r.Response)
		}
		if r.Model != "test-model" {
			t.Errorf("line %d: expected model 'test-model', got %q", i, r.Model)
		}
		if r.ElapsedSeconds <= 0 {
			t.Errorf("line %d: expected positive elapsed_seconds, got %f", i, r.ElapsedSeconds)
		}
	}
}
// TestExpandPromptsSkipsCompleted checks that prompts whose IDs the mock
// InfluxDB reports as completed are not sent to the chat API: of p1..p3 with
// p1 and p2 completed, only p3 is requested and written to the output file.
func TestExpandPromptsSkipsCompleted(t *testing.T) {
	var chatHits atomic.Int32
	chat := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		if req.URL.Path != "/v1/chat/completions" {
			return
		}
		chatHits.Add(1)
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(ChatResponse{
			Choices: []Choice{
				{Message: Message{Role: "assistant", Content: "New response."}},
			},
		})
	}))
	defer chat.Close()

	// The mock InfluxDB answers every query with p1 and p2 as done.
	db := mockInfluxServer(t, func(string) ([]map[string]any, int) {
		done := []map[string]any{
			{"seed_id": "p1"},
			{"seed_id": "p2"},
		}
		return done, http.StatusOK
	})
	defer db.Close()

	t.Setenv("INFLUX_TOKEN", "test-token")
	influx := NewInfluxClient(db.URL, "training")
	client := NewClient(chat.URL, "test-model")
	client.MaxTokens = 2048

	dir := t.TempDir()
	seeds := []Response{
		{ID: "p1", Domain: "ethics", Prompt: "What is ethics?"},
		{ID: "p2", Domain: "sovereignty", Prompt: "Define sovereignty."},
		{ID: "p3", Domain: "consent", Prompt: "Explain consent."},
	}
	if err := expandPrompts(client, influx, nil, seeds, "test-model", "test-worker", dir, false); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Exactly one request is expected, for p3.
	if n := chatHits.Load(); n != 1 {
		t.Errorf("expected 1 API call (p3 only), got %d", n)
	}

	// The output file must hold a single record, and it must be p3.
	raw, err := os.ReadFile(filepath.Join(dir, "expand-test-worker.jsonl"))
	if err != nil {
		t.Fatalf("read output file: %v", err)
	}
	rows := strings.Split(strings.TrimSpace(string(raw)), "\n")
	if len(rows) != 1 {
		t.Fatalf("expected 1 output line, got %d", len(rows))
	}
	var rec Response
	if err := json.Unmarshal([]byte(rows[0]), &rec); err != nil {
		t.Fatalf("parse output line: %v", err)
	}
	if rec.ID != "p3" {
		t.Errorf("expected ID 'p3', got %q", rec.ID)
	}
}
// TestExpandPromptsAllCompleted checks that when every prompt ID is reported
// as completed, expandPrompts succeeds without calling the chat API and
// without creating the worker's output file.
func TestExpandPromptsAllCompleted(t *testing.T) {
	var chatHits atomic.Int32
	// The chat mock only counts requests; it never writes a body.
	chat := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) {
		if req.URL.Path == "/v1/chat/completions" {
			chatHits.Add(1)
		}
	}))
	defer chat.Close()

	// Both prompt IDs come back from the completed-IDs query.
	db := mockInfluxServer(t, func(string) ([]map[string]any, int) {
		done := []map[string]any{
			{"seed_id": "p1"},
			{"seed_id": "p2"},
		}
		return done, http.StatusOK
	})
	defer db.Close()

	t.Setenv("INFLUX_TOKEN", "test-token")
	influx := NewInfluxClient(db.URL, "training")
	client := NewClient(chat.URL, "test-model")

	dir := t.TempDir()
	seeds := []Response{
		{ID: "p1", Domain: "ethics", Prompt: "What is ethics?"},
		{ID: "p2", Domain: "sovereignty", Prompt: "Define sovereignty."},
	}
	if err := expandPrompts(client, influx, nil, seeds, "test-model", "test-worker", dir, false); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Nothing should have reached the chat API.
	if n := chatHits.Load(); n != 0 {
		t.Errorf("expected 0 API calls, got %d", n)
	}

	// No output file should have been created.
	_, statErr := os.Stat(filepath.Join(dir, "expand-test-worker.jsonl"))
	if !os.IsNotExist(statErr) {
		t.Error("output file should not exist when all prompts are completed")
	}
}
// TestExpandPromptsDryRun checks that with the dry-run flag set, expandPrompts
// succeeds without calling the chat API and without creating the worker's
// output file, even though two of the three prompts are still pending.
func TestExpandPromptsDryRun(t *testing.T) {
	var chatHits atomic.Int32
	// The chat mock only counts requests; it never writes a body.
	chat := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) {
		if req.URL.Path == "/v1/chat/completions" {
			chatHits.Add(1)
		}
	}))
	defer chat.Close()

	// Only p1 is reported as completed.
	db := mockInfluxServer(t, func(string) ([]map[string]any, int) {
		done := []map[string]any{
			{"seed_id": "p1"},
		}
		return done, http.StatusOK
	})
	defer db.Close()

	t.Setenv("INFLUX_TOKEN", "test-token")
	influx := NewInfluxClient(db.URL, "training")
	client := NewClient(chat.URL, "test-model")

	dir := t.TempDir()
	seeds := []Response{
		{ID: "p1", Domain: "ethics", Prompt: "What is ethics?"},
		{ID: "p2", Domain: "sovereignty", Prompt: "Define sovereignty."},
		{ID: "p3", Domain: "consent", Prompt: "Explain consent."},
	}
	const dryRun = true
	if err := expandPrompts(client, influx, nil, seeds, "test-model", "test-worker", dir, dryRun); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Dry-run must not reach the chat API.
	if n := chatHits.Load(); n != 0 {
		t.Errorf("expected 0 API calls in dry-run mode, got %d", n)
	}

	// Dry-run must not create the output file.
	_, statErr := os.Stat(filepath.Join(dir, "expand-test-worker.jsonl"))
	if !os.IsNotExist(statErr) {
		t.Error("output file should not exist in dry-run mode")
	}
}
// TestExpandPromptsAPIErrorSkipsPrompt checks that a chat API failure for one
// prompt does not abort the run: the first chat request is answered with
// 400, later ones succeed, expandPrompts returns nil, and only p2 ends up in
// the output file.
func TestExpandPromptsAPIErrorSkipsPrompt(t *testing.T) {
	var chatHits atomic.Int32
	chat := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		if req.URL.Path != "/v1/chat/completions" {
			return
		}
		// Only the very first request is rejected (400, treated as non-retryable).
		if chatHits.Add(1) == 1 {
			w.WriteHeader(http.StatusBadRequest)
			_, _ = w.Write([]byte("bad request"))
			return
		}
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(ChatResponse{
			Choices: []Choice{
				{Message: Message{Role: "assistant", Content: "Success response."}},
			},
		})
	}))
	defer chat.Close()

	// Nothing is reported as completed.
	db := mockInfluxServer(t, func(string) ([]map[string]any, int) {
		return []map[string]any{}, http.StatusOK
	})
	defer db.Close()

	t.Setenv("INFLUX_TOKEN", "test-token")
	influx := NewInfluxClient(db.URL, "training")
	client := NewClient(chat.URL, "test-model")
	client.MaxTokens = 2048

	dir := t.TempDir()
	seeds := []Response{
		{ID: "p1", Domain: "ethics", Prompt: "Will fail"},
		{ID: "p2", Domain: "sovereignty", Prompt: "Will succeed"},
	}

	// A per-prompt failure is expected to be skipped, not returned.
	if err := expandPrompts(client, influx, nil, seeds, "test-model", "test-worker", dir, false); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Only p2 should have been written.
	raw, err := os.ReadFile(filepath.Join(dir, "expand-test-worker.jsonl"))
	if err != nil {
		t.Fatalf("read output file: %v", err)
	}
	rows := strings.Split(strings.TrimSpace(string(raw)), "\n")
	if len(rows) != 1 {
		t.Fatalf("expected 1 output line (p2 only), got %d", len(rows))
	}
	var rec Response
	if err := json.Unmarshal([]byte(rows[0]), &rec); err != nil {
		t.Fatalf("parse output line: %v", err)
	}
	if rec.ID != "p2" {
		t.Errorf("expected ID 'p2', got %q", rec.ID)
	}
}
// TestExpandPromptsInfluxWriteErrorNonFatal checks that failing InfluxDB
// writes do not fail the run: the mock answers queries normally but rejects
// every write with 500, and expandPrompts must still return nil and record
// the generated response in the output file.
func TestExpandPromptsInfluxWriteErrorNonFatal(t *testing.T) {
	// InfluxDB write failures should be logged but not crash the run.
	apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/v1/chat/completions" {
			resp := ChatResponse{
				Choices: []Choice{
					{Message: Message{Role: "assistant", Content: "Good response."}},
				},
			}
			w.Header().Set("Content-Type", "application/json")
			json.NewEncoder(w).Encode(resp)
		}
	}))
	defer apiServer.Close()

	// InfluxDB: query works but writes fail.
	influxServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/api/v3/query_sql":
			w.Header().Set("Content-Type", "application/json")
			json.NewEncoder(w).Encode([]map[string]interface{}{})
		case "/api/v3/write_lp":
			w.WriteHeader(http.StatusInternalServerError)
			w.Write([]byte("write failed"))
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	}))
	defer influxServer.Close()

	t.Setenv("INFLUX_TOKEN", "test-token")
	influx := NewInfluxClient(influxServer.URL, "training")
	client := NewClient(apiServer.URL, "test-model")
	client.MaxTokens = 2048
	outputDir := t.TempDir()

	prompts := []Response{
		{ID: "p1", Domain: "ethics", Prompt: "What is ethics?"},
	}

	// Should succeed even though InfluxDB writes fail.
	err := expandPrompts(client, influx, nil, prompts, "test-model", "test-worker", outputDir, false)
	if err != nil {
		t.Fatalf("unexpected error (InfluxDB write failure should be non-fatal): %v", err)
	}

	// Output file should still have the response.
	outputFile := filepath.Join(outputDir, "expand-test-worker.jsonl")
	data, err := os.ReadFile(outputFile)
	if err != nil {
		t.Fatalf("read output file: %v", err)
	}
	lines := strings.Split(strings.TrimSpace(string(data)), "\n")
	if len(lines) != 1 {
		t.Fatalf("expected 1 output line, got %d", len(lines))
	}

	// strings.Split returns a single empty element for empty input, so the
	// count above cannot tell an empty file from a one-line file. Parse the
	// line to prove the record was actually written.
	var r Response
	if err := json.Unmarshal([]byte(lines[0]), &r); err != nil {
		t.Fatalf("parse output line: %v", err)
	}
	if r.ID != "p1" {
		t.Errorf("expected ID 'p1', got %q", r.ID)
	}
	if r.Response != "Good response." {
		t.Errorf("expected response 'Good response.', got %q", r.Response)
	}
}
// TestExpandPromptsOutputJSONLStructure runs a single prompt through
// expandPrompts and verifies every field of the resulting JSONL record:
// id, domain and prompt are carried over from the input, response is the
// chat API content, model is the model name passed in, and elapsed_seconds
// is positive.
func TestExpandPromptsOutputJSONLStructure(t *testing.T) {
	chat := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		if req.URL.Path != "/v1/chat/completions" {
			return
		}
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(ChatResponse{
			Choices: []Choice{
				{Message: Message{Role: "assistant", Content: "A detailed response about consent and autonomy."}},
			},
		})
	}))
	defer chat.Close()

	// Nothing is reported as completed.
	db := mockInfluxServer(t, func(string) ([]map[string]any, int) {
		return []map[string]any{}, http.StatusOK
	})
	defer db.Close()

	t.Setenv("INFLUX_TOKEN", "test-token")
	influx := NewInfluxClient(db.URL, "training")
	client := NewClient(chat.URL, "test-model")
	client.MaxTokens = 2048

	dir := t.TempDir()
	seeds := []Response{
		{ID: "expand_42", Domain: "lek_consent", Prompt: "What does consent mean in AI ethics?"},
	}
	if err := expandPrompts(client, influx, nil, seeds, "lem-gemma-3-1b", "snider-linux", dir, false); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// The file name is derived from the worker name.
	raw, err := os.ReadFile(filepath.Join(dir, "expand-snider-linux.jsonl"))
	if err != nil {
		t.Fatalf("read output file: %v", err)
	}

	var rec Response
	if err := json.Unmarshal([]byte(strings.TrimSpace(string(raw))), &rec); err != nil {
		t.Fatalf("parse output: %v", err)
	}

	// String fields are compared in the same order as they were originally
	// checked so failure output keeps its sequence.
	stringFields := []struct {
		label, got, want string
	}{
		{"id", rec.ID, "expand_42"},
		{"domain", rec.Domain, "lek_consent"},
		{"prompt", rec.Prompt, "What does consent mean in AI ethics?"},
		{"response", rec.Response, "A detailed response about consent and autonomy."},
		{"model", rec.Model, "lem-gemma-3-1b"},
	}
	for _, f := range stringFields {
		if f.got != f.want {
			t.Errorf("%s = %q, want '%s'", f.label, f.got, f.want)
		}
	}
	if rec.ElapsedSeconds <= 0 {
		t.Errorf("elapsed_seconds should be > 0, got %f", rec.ElapsedSeconds)
	}
}
// TestExpandPromptsInfluxLineProtocol captures every body posted to the mock
// InfluxDB write endpoint during a single-prompt run and checks that the
// combined payload mentions the expected measurements, tags and fields.
func TestExpandPromptsInfluxLineProtocol(t *testing.T) {
	chat := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		if req.URL.Path != "/v1/chat/completions" {
			return
		}
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(ChatResponse{
			Choices: []Choice{
				{Message: Message{Role: "assistant", Content: "Response text here."}},
			},
		})
	}))
	defer chat.Close()

	// Mock InfluxDB: queries return no rows, writes are recorded verbatim.
	var captured []string
	db := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		switch {
		case req.URL.Path == "/api/v3/query_sql":
			w.Header().Set("Content-Type", "application/json")
			_ = json.NewEncoder(w).Encode([]map[string]any{})
		case req.URL.Path == "/api/v3/write_lp":
			payload, _ := io.ReadAll(req.Body)
			captured = append(captured, string(payload))
			w.WriteHeader(http.StatusOK)
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	}))
	defer db.Close()

	t.Setenv("INFLUX_TOKEN", "test-token")
	influx := NewInfluxClient(db.URL, "training")
	client := NewClient(chat.URL, "test-model")
	client.MaxTokens = 2048

	dir := t.TempDir()
	seeds := []Response{
		{ID: "p1", Domain: "ethics", Prompt: "Test prompt"},
	}
	if err := expandPrompts(client, influx, nil, seeds, "test-model", "my-worker", dir, false); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// At least one write must have arrived.
	if len(captured) == 0 {
		t.Fatal("expected at least one InfluxDB write, got none")
	}

	// Each fragment must appear somewhere in the combined writes.
	combined := strings.Join(captured, "\n")
	expectations := []struct {
		fragment string
		failure  string
	}{
		{"expansion_gen", "expected 'expansion_gen' measurement in InfluxDB writes"},
		{"expansion_progress", "expected 'expansion_progress' measurement in InfluxDB writes"},
		{`seed_id="p1"`, "expected seed_id=\"p1\" in InfluxDB writes"},
		{"w=my-worker", "expected w=my-worker tag in InfluxDB writes"},
		{"d=ethics", "expected d=ethics tag in InfluxDB writes"},
		{`model="test-model"`, "expected model=\"test-model\" in InfluxDB writes"},
	}
	for _, e := range expectations {
		if !strings.Contains(combined, e.fragment) {
			t.Error(e.failure)
		}
	}
}
// TestExpandPromptsAppendMode checks that expandPrompts appends to an
// existing worker output file instead of truncating it: a pre-written line
// must survive unchanged and the new record must follow it.
func TestExpandPromptsAppendMode(t *testing.T) {
	apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/v1/chat/completions" {
			resp := ChatResponse{
				Choices: []Choice{
					{Message: Message{Role: "assistant", Content: "Appended response."}},
				},
			}
			w.Header().Set("Content-Type", "application/json")
			json.NewEncoder(w).Encode(resp)
		}
	}))
	defer apiServer.Close()

	influxServer := mockInfluxServer(t, func(q string) ([]map[string]interface{}, int) {
		return []map[string]interface{}{}, http.StatusOK
	})
	defer influxServer.Close()

	t.Setenv("INFLUX_TOKEN", "test-token")
	influx := NewInfluxClient(influxServer.URL, "training")
	client := NewClient(apiServer.URL, "test-model")
	client.MaxTokens = 2048
	outputDir := t.TempDir()
	outputFile := filepath.Join(outputDir, "expand-test-worker.jsonl")

	// Pre-write some content to the output file.
	existingLine := `{"id":"existing","domain":"test","prompt":"pre-existing","response":"old","model":"old-model"}`
	if err := os.WriteFile(outputFile, []byte(existingLine+"\n"), 0644); err != nil {
		t.Fatalf("write existing file: %v", err)
	}

	prompts := []Response{
		{ID: "p_new", Domain: "ethics", Prompt: "New prompt"},
	}
	err := expandPrompts(client, influx, nil, prompts, "test-model", "test-worker", outputDir, false)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Read the output file and verify it has 2 lines (existing + new).
	data, err := os.ReadFile(outputFile)
	if err != nil {
		t.Fatalf("read output file: %v", err)
	}
	scanner := bufio.NewScanner(strings.NewReader(strings.TrimSpace(string(data))))
	var lineCount int
	for scanner.Scan() {
		lineCount++
	}
	// A scan error (e.g. an over-long line) would otherwise look like a
	// short file.
	if err := scanner.Err(); err != nil {
		t.Fatalf("scan output file: %v", err)
	}
	// Fatal rather than Error: the checks below index lines[1], which would
	// panic if the new record were missing.
	if lineCount != 2 {
		t.Fatalf("expected 2 lines (existing + new), got %d", lineCount)
	}

	// First line should be the existing content.
	lines := strings.Split(strings.TrimSpace(string(data)), "\n")
	if lines[0] != existingLine {
		t.Errorf("first line should be preserved, got: %s", lines[0])
	}

	// Second line should be the new response.
	var r Response
	if err := json.Unmarshal([]byte(lines[1]), &r); err != nil {
		t.Fatalf("parse second line: %v", err)
	}
	if r.ID != "p_new" {
		t.Errorf("second line ID = %q, want 'p_new'", r.ID)
	}
}
// TestExpandPromptsLimit passes five pending prompts with a trailing limit
// argument of 2 and checks that exactly two chat requests are made and that
// the output file holds p1 and p2, in that order.
func TestExpandPromptsLimit(t *testing.T) {
	var chatHits atomic.Int32
	chat := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		if req.URL.Path != "/v1/chat/completions" {
			return
		}
		chatHits.Add(1)
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(ChatResponse{
			Choices: []Choice{
				{Message: Message{Role: "assistant", Content: "Limited response."}},
			},
		})
	}))
	defer chat.Close()

	// Nothing is reported as completed.
	db := mockInfluxServer(t, func(string) ([]map[string]any, int) {
		return []map[string]any{}, http.StatusOK
	})
	defer db.Close()

	t.Setenv("INFLUX_TOKEN", "test-token")
	influx := NewInfluxClient(db.URL, "training")
	client := NewClient(chat.URL, "test-model")
	client.MaxTokens = 2048
	dir := t.TempDir()

	// Build p1..p5.
	const total = 5
	seeds := make([]Response, 0, total)
	for n := 1; n <= total; n++ {
		seeds = append(seeds, Response{
			ID:     fmt.Sprintf("p%d", n),
			Domain: "test",
			Prompt: fmt.Sprintf("Prompt %d", n),
		})
	}

	// All five are supplied; the trailing argument caps processing at two.
	if err := expandPrompts(client, influx, nil, seeds, "test-model", "test-worker", dir, false, 2); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if n := chatHits.Load(); n != 2 {
		t.Errorf("expected 2 API calls (limit=2), got %d", n)
	}

	// The output file must contain exactly two records.
	raw, err := os.ReadFile(filepath.Join(dir, "expand-test-worker.jsonl"))
	if err != nil {
		t.Fatalf("read output file: %v", err)
	}
	rows := strings.Split(strings.TrimSpace(string(raw)), "\n")
	if len(rows) != 2 {
		t.Fatalf("expected 2 output lines, got %d", len(rows))
	}

	// Record i must carry ID p(i+1), i.e. p1 then p2.
	for idx, row := range rows {
		var rec Response
		if err := json.Unmarshal([]byte(row), &rec); err != nil {
			t.Fatalf("line %d: failed to parse JSON: %v", idx, err)
		}
		if want := fmt.Sprintf("p%d", idx+1); rec.ID != want {
			t.Errorf("line %d: expected ID %q, got %q", idx, want, rec.ID)
		}
	}
}
// TestExpandPromptsLimitAfterFiltering checks that the limit is counted
// against the prompts that remain after completed IDs are removed: with
// p1 and p2 completed and a limit of 2, the run must process p3 and p4.
func TestExpandPromptsLimitAfterFiltering(t *testing.T) {
	var chatHits atomic.Int32
	chat := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		if req.URL.Path != "/v1/chat/completions" {
			return
		}
		chatHits.Add(1)
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(ChatResponse{
			Choices: []Choice{
				{Message: Message{Role: "assistant", Content: "Response."}},
			},
		})
	}))
	defer chat.Close()

	// The mock InfluxDB reports p1 and p2 as already done.
	db := mockInfluxServer(t, func(string) ([]map[string]any, int) {
		done := []map[string]any{
			{"seed_id": "p1"},
			{"seed_id": "p2"},
		}
		return done, http.StatusOK
	})
	defer db.Close()

	t.Setenv("INFLUX_TOKEN", "test-token")
	influx := NewInfluxClient(db.URL, "training")
	client := NewClient(chat.URL, "test-model")
	client.MaxTokens = 2048
	dir := t.TempDir()

	// Build p1..p5; three of them (p3, p4, p5) are still pending.
	const total = 5
	seeds := make([]Response, 0, total)
	for n := 1; n <= total; n++ {
		seeds = append(seeds, Response{
			ID:     fmt.Sprintf("p%d", n),
			Domain: "test",
			Prompt: fmt.Sprintf("Prompt %d", n),
		})
	}

	if err := expandPrompts(client, influx, nil, seeds, "test-model", "test-worker", dir, false, 2); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Two of the three pending prompts should have been requested.
	if n := chatHits.Load(); n != 2 {
		t.Errorf("expected 2 API calls (limit=2 after filtering), got %d", n)
	}

	// Collect the IDs written to the output file, in file order.
	raw, err := os.ReadFile(filepath.Join(dir, "expand-test-worker.jsonl"))
	if err != nil {
		t.Fatalf("read output file: %v", err)
	}
	var gotIDs []string
	for _, row := range strings.Split(strings.TrimSpace(string(raw)), "\n") {
		var rec Response
		if err := json.Unmarshal([]byte(row), &rec); err != nil {
			t.Fatalf("parse output line: %v", err)
		}
		gotIDs = append(gotIDs, rec.ID)
	}

	if len(gotIDs) != 2 {
		t.Fatalf("expected 2 processed IDs, got %d", len(gotIDs))
	}
	if first := gotIDs[0]; first != "p3" {
		t.Errorf("expected first processed ID to be 'p3', got %q", first)
	}
	if second := gotIDs[1]; second != "p4" {
		t.Errorf("expected second processed ID to be 'p4', got %q", second)
	}
}
// TestExpandPromptsLimitZeroMeansAll checks that an explicit limit of 0 does
// not restrict the run: all three pending prompts are sent to the chat API.
func TestExpandPromptsLimitZeroMeansAll(t *testing.T) {
	var chatHits atomic.Int32
	chat := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		if req.URL.Path != "/v1/chat/completions" {
			return
		}
		chatHits.Add(1)
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(ChatResponse{
			Choices: []Choice{
				{Message: Message{Role: "assistant", Content: "Response."}},
			},
		})
	}))
	defer chat.Close()

	// Nothing is reported as completed.
	db := mockInfluxServer(t, func(string) ([]map[string]any, int) {
		return []map[string]any{}, http.StatusOK
	})
	defer db.Close()

	t.Setenv("INFLUX_TOKEN", "test-token")
	influx := NewInfluxClient(db.URL, "training")
	client := NewClient(chat.URL, "test-model")
	client.MaxTokens = 2048
	dir := t.TempDir()

	// Build p1..p3.
	const total = 3
	seeds := make([]Response, 0, total)
	for n := 1; n <= total; n++ {
		seeds = append(seeds, Response{
			ID:     fmt.Sprintf("p%d", n),
			Domain: "test",
			Prompt: fmt.Sprintf("Prompt %d", n),
		})
	}

	// The limit is passed explicitly as zero.
	const limit = 0
	if err := expandPrompts(client, influx, nil, seeds, "test-model", "test-worker", dir, false, limit); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if n := chatHits.Load(); n != 3 {
		t.Errorf("expected 3 API calls (limit=0 means all), got %d", n)
	}
}
// TestExpandPromptsOutputHasCharsField checks that each JSONL record carries
// a numeric "chars" field equal to the byte length of the chat response.
func TestExpandPromptsOutputHasCharsField(t *testing.T) {
	// NOTE(review): despite its wording this literal is 44 bytes long. The
	// expected value below is derived from len(), so the assertion is
	// unaffected.
	responseText := "This is exactly forty-seven characters long!"

	chat := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		if req.URL.Path != "/v1/chat/completions" {
			return
		}
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(ChatResponse{
			Choices: []Choice{
				{Message: Message{Role: "assistant", Content: responseText}},
			},
		})
	}))
	defer chat.Close()

	// Nothing is reported as completed.
	db := mockInfluxServer(t, func(string) ([]map[string]any, int) {
		return []map[string]any{}, http.StatusOK
	})
	defer db.Close()

	t.Setenv("INFLUX_TOKEN", "test-token")
	influx := NewInfluxClient(db.URL, "training")
	client := NewClient(chat.URL, "test-model")
	client.MaxTokens = 2048
	dir := t.TempDir()

	seeds := []Response{
		{ID: "p1", Domain: "test", Prompt: "Test"},
	}
	if err := expandPrompts(client, influx, nil, seeds, "test-model", "test-worker", dir, false); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	raw, err := os.ReadFile(filepath.Join(dir, "expand-test-worker.jsonl"))
	if err != nil {
		t.Fatalf("read output file: %v", err)
	}

	// Decode into a generic map so the field is checked by its JSON key
	// rather than through the Response struct.
	var record map[string]any
	if err := json.Unmarshal([]byte(strings.TrimSpace(string(raw))), &record); err != nil {
		t.Fatalf("parse raw JSON: %v", err)
	}
	field, present := record["chars"]
	if !present {
		t.Fatal("output JSON missing 'chars' field")
	}
	// encoding/json decodes JSON numbers into float64 for interface targets.
	count, isNumber := field.(float64)
	if !isNumber {
		t.Fatalf("chars field is not a number, got %T", field)
	}
	if got, want := int(count), len(responseText); got != want {
		t.Errorf("chars = %d, want %d", got, want)
	}
}
// TestExpandPromptsGetCompletedIDsErrorNonFatal checks that expandPrompts
// returns an error when the InfluxDB server answers every request with 500,
// since the set of finished prompts cannot be determined.
//
// NOTE(review): the "NonFatal" suffix in the name contradicts what is
// asserted here (an error IS expected); the name is kept so existing -run
// filters keep working.
func TestExpandPromptsGetCompletedIDsErrorNonFatal(t *testing.T) {
	// The chat mock answers every path with a valid completion.
	chat := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(ChatResponse{
			Choices: []Choice{
				{Message: Message{Role: "assistant", Content: "Response."}},
			},
		})
	}))
	defer chat.Close()

	// The InfluxDB mock fails every request, queries included.
	db := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		_, _ = w.Write([]byte("query failed"))
	}))
	defer db.Close()

	t.Setenv("INFLUX_TOKEN", "test-token")
	influx := NewInfluxClient(db.URL, "training")
	client := NewClient(chat.URL, "test-model")
	dir := t.TempDir()

	seeds := []Response{
		{ID: "p1", Domain: "test", Prompt: "Test"},
	}
	if err := expandPrompts(client, influx, nil, seeds, "test-model", "test-worker", dir, false); err == nil {
		t.Fatal("expected error when getCompletedIDs fails")
	}
}