484 lines
14 KiB
Go
484 lines
14 KiB
Go
|
|
package main
|
||
|
|
|
||
|
|
import (
|
||
|
|
"bufio"
|
||
|
|
"encoding/json"
|
||
|
|
"os"
|
||
|
|
"path/filepath"
|
||
|
|
"strings"
|
||
|
|
"testing"
|
||
|
|
)
|
||
|
|
|
||
|
|
func TestFilterResponses(t *testing.T) {
|
||
|
|
tests := []struct {
|
||
|
|
name string
|
||
|
|
input []Response
|
||
|
|
want int
|
||
|
|
}{
|
||
|
|
{
|
||
|
|
name: "empty input",
|
||
|
|
input: []Response{},
|
||
|
|
want: 0,
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "all valid",
|
||
|
|
input: []Response{
|
||
|
|
{ID: "1", Prompt: "hello", Response: strings.Repeat("a", 50), Model: "m"},
|
||
|
|
{ID: "2", Prompt: "world", Response: strings.Repeat("b", 100), Model: "m"},
|
||
|
|
},
|
||
|
|
want: 2,
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "filter empty response",
|
||
|
|
input: []Response{
|
||
|
|
{ID: "1", Prompt: "hello", Response: "", Model: "m"},
|
||
|
|
{ID: "2", Prompt: "world", Response: strings.Repeat("b", 50), Model: "m"},
|
||
|
|
},
|
||
|
|
want: 1,
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "filter error prefix",
|
||
|
|
input: []Response{
|
||
|
|
{ID: "1", Prompt: "hello", Response: "ERROR: something went wrong", Model: "m"},
|
||
|
|
{ID: "2", Prompt: "world", Response: strings.Repeat("b", 50), Model: "m"},
|
||
|
|
},
|
||
|
|
want: 1,
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "filter short response under 50 chars",
|
||
|
|
input: []Response{
|
||
|
|
{ID: "1", Prompt: "hello", Response: strings.Repeat("a", 49), Model: "m"},
|
||
|
|
{ID: "2", Prompt: "world", Response: strings.Repeat("b", 50), Model: "m"},
|
||
|
|
},
|
||
|
|
want: 1,
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "filter all bad",
|
||
|
|
input: []Response{
|
||
|
|
{ID: "1", Prompt: "p1", Response: "", Model: "m"},
|
||
|
|
{ID: "2", Prompt: "p2", Response: "ERROR: fail", Model: "m"},
|
||
|
|
{ID: "3", Prompt: "p3", Response: "too short", Model: "m"},
|
||
|
|
},
|
||
|
|
want: 0,
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "exactly 50 chars passes",
|
||
|
|
input: []Response{
|
||
|
|
{ID: "1", Prompt: "hello", Response: strings.Repeat("x", 50), Model: "m"},
|
||
|
|
},
|
||
|
|
want: 1,
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "ERROR prefix is case sensitive",
|
||
|
|
input: []Response{
|
||
|
|
{ID: "1", Prompt: "hello", Response: strings.Repeat("error: lowercase is fine and long enough to pass", 2), Model: "m"},
|
||
|
|
},
|
||
|
|
want: 1,
|
||
|
|
},
|
||
|
|
}
|
||
|
|
|
||
|
|
for _, tt := range tests {
|
||
|
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
|
got := filterResponses(tt.input)
|
||
|
|
if len(got) != tt.want {
|
||
|
|
t.Errorf("filterResponses() returned %d responses, want %d", len(got), tt.want)
|
||
|
|
}
|
||
|
|
})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestSplitData(t *testing.T) {
|
||
|
|
// Create 100 responses for easy percentage calculation.
|
||
|
|
responses := make([]Response, 100)
|
||
|
|
for i := 0; i < 100; i++ {
|
||
|
|
responses[i] = Response{ID: "r" + string(rune('0'+i/10)) + string(rune('0'+i%10))}
|
||
|
|
}
|
||
|
|
|
||
|
|
tests := []struct {
|
||
|
|
name string
|
||
|
|
trainPct, validPct, testPct int
|
||
|
|
wantTrain, wantValid, wantTest int
|
||
|
|
}{
|
||
|
|
{
|
||
|
|
name: "default 90/5/5",
|
||
|
|
trainPct: 90,
|
||
|
|
validPct: 5,
|
||
|
|
testPct: 5,
|
||
|
|
wantTrain: 90,
|
||
|
|
wantValid: 5,
|
||
|
|
wantTest: 5,
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "80/10/10",
|
||
|
|
trainPct: 80,
|
||
|
|
validPct: 10,
|
||
|
|
testPct: 10,
|
||
|
|
wantTrain: 80,
|
||
|
|
wantValid: 10,
|
||
|
|
wantTest: 10,
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "100/0/0",
|
||
|
|
trainPct: 100,
|
||
|
|
validPct: 0,
|
||
|
|
testPct: 0,
|
||
|
|
wantTrain: 100,
|
||
|
|
wantValid: 0,
|
||
|
|
wantTest: 0,
|
||
|
|
},
|
||
|
|
}
|
||
|
|
|
||
|
|
for _, tt := range tests {
|
||
|
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
|
train, valid, test := splitData(responses, tt.trainPct, tt.validPct, tt.testPct, 42)
|
||
|
|
if len(train) != tt.wantTrain {
|
||
|
|
t.Errorf("train size = %d, want %d", len(train), tt.wantTrain)
|
||
|
|
}
|
||
|
|
if len(valid) != tt.wantValid {
|
||
|
|
t.Errorf("valid size = %d, want %d", len(valid), tt.wantValid)
|
||
|
|
}
|
||
|
|
if len(test) != tt.wantTest {
|
||
|
|
t.Errorf("test size = %d, want %d", len(test), tt.wantTest)
|
||
|
|
}
|
||
|
|
})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestSplitDataDeterministic(t *testing.T) {
|
||
|
|
responses := make([]Response, 20)
|
||
|
|
for i := range responses {
|
||
|
|
responses[i] = Response{ID: "r" + string(rune('A'+i))}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Same seed should produce same split.
|
||
|
|
train1, valid1, test1 := splitData(responses, 80, 10, 10, 42)
|
||
|
|
train2, valid2, test2 := splitData(responses, 80, 10, 10, 42)
|
||
|
|
|
||
|
|
for i := range train1 {
|
||
|
|
if train1[i].ID != train2[i].ID {
|
||
|
|
t.Errorf("train[%d]: got %s and %s with same seed", i, train1[i].ID, train2[i].ID)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
for i := range valid1 {
|
||
|
|
if valid1[i].ID != valid2[i].ID {
|
||
|
|
t.Errorf("valid[%d]: got %s and %s with same seed", i, valid1[i].ID, valid2[i].ID)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
for i := range test1 {
|
||
|
|
if test1[i].ID != test2[i].ID {
|
||
|
|
t.Errorf("test[%d]: got %s and %s with same seed", i, test1[i].ID, test2[i].ID)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestSplitDataDifferentSeed(t *testing.T) {
|
||
|
|
responses := make([]Response, 50)
|
||
|
|
for i := range responses {
|
||
|
|
responses[i] = Response{ID: "r" + string(rune('A'+i%26)) + string(rune('0'+i/26))}
|
||
|
|
}
|
||
|
|
|
||
|
|
train1, _, _ := splitData(responses, 80, 10, 10, 42)
|
||
|
|
train2, _, _ := splitData(responses, 80, 10, 10, 99)
|
||
|
|
|
||
|
|
// Different seeds should (almost certainly) produce different orderings.
|
||
|
|
different := false
|
||
|
|
for i := range train1 {
|
||
|
|
if train1[i].ID != train2[i].ID {
|
||
|
|
different = true
|
||
|
|
break
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if !different {
|
||
|
|
t.Error("different seeds produced identical orderings, expected different")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestSplitDataRemainder(t *testing.T) {
|
||
|
|
// 7 items with 90/5/5: train=6, valid=0, test=0 — remainder goes to test.
|
||
|
|
// Actually: train = 7*90/100 = 6, valid = 7*5/100 = 0, test = 7 - 6 - 0 = 1.
|
||
|
|
responses := make([]Response, 7)
|
||
|
|
for i := range responses {
|
||
|
|
responses[i] = Response{ID: "r"}
|
||
|
|
}
|
||
|
|
|
||
|
|
train, valid, test := splitData(responses, 90, 5, 5, 42)
|
||
|
|
total := len(train) + len(valid) + len(test)
|
||
|
|
if total != 7 {
|
||
|
|
t.Errorf("total split size = %d, want 7", total)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestWriteTrainingJSONL(t *testing.T) {
|
||
|
|
dir := t.TempDir()
|
||
|
|
path := filepath.Join(dir, "train.jsonl")
|
||
|
|
|
||
|
|
responses := []Response{
|
||
|
|
{ID: "1", Prompt: "What is ethics?", Response: "Ethics is the study of moral principles.", Model: "m"},
|
||
|
|
{ID: "2", Prompt: "Define AI.", Response: "Artificial Intelligence is a field of computer science.", Model: "m"},
|
||
|
|
}
|
||
|
|
|
||
|
|
if err := writeTrainingJSONL(path, responses); err != nil {
|
||
|
|
t.Fatalf("writeTrainingJSONL() error: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Read back and verify.
|
||
|
|
f, err := os.Open(path)
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("failed to open output: %v", err)
|
||
|
|
}
|
||
|
|
defer f.Close()
|
||
|
|
|
||
|
|
scanner := bufio.NewScanner(f)
|
||
|
|
var examples []TrainingExample
|
||
|
|
for scanner.Scan() {
|
||
|
|
line := strings.TrimSpace(scanner.Text())
|
||
|
|
if line == "" {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
var ex TrainingExample
|
||
|
|
if err := json.Unmarshal([]byte(line), &ex); err != nil {
|
||
|
|
t.Fatalf("failed to unmarshal line: %v", err)
|
||
|
|
}
|
||
|
|
examples = append(examples, ex)
|
||
|
|
}
|
||
|
|
|
||
|
|
if len(examples) != 2 {
|
||
|
|
t.Fatalf("expected 2 examples, got %d", len(examples))
|
||
|
|
}
|
||
|
|
|
||
|
|
// Verify first example.
|
||
|
|
if len(examples[0].Messages) != 2 {
|
||
|
|
t.Fatalf("expected 2 messages, got %d", len(examples[0].Messages))
|
||
|
|
}
|
||
|
|
if examples[0].Messages[0].Role != "user" {
|
||
|
|
t.Errorf("messages[0].role = %q, want %q", examples[0].Messages[0].Role, "user")
|
||
|
|
}
|
||
|
|
if examples[0].Messages[0].Content != "What is ethics?" {
|
||
|
|
t.Errorf("messages[0].content = %q, want %q", examples[0].Messages[0].Content, "What is ethics?")
|
||
|
|
}
|
||
|
|
if examples[0].Messages[1].Role != "assistant" {
|
||
|
|
t.Errorf("messages[1].role = %q, want %q", examples[0].Messages[1].Role, "assistant")
|
||
|
|
}
|
||
|
|
if examples[0].Messages[1].Content != "Ethics is the study of moral principles." {
|
||
|
|
t.Errorf("messages[1].content = %q, want %q", examples[0].Messages[1].Content, "Ethics is the study of moral principles.")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestWriteTrainingJSONLEmpty(t *testing.T) {
|
||
|
|
dir := t.TempDir()
|
||
|
|
path := filepath.Join(dir, "empty.jsonl")
|
||
|
|
|
||
|
|
if err := writeTrainingJSONL(path, []Response{}); err != nil {
|
||
|
|
t.Fatalf("writeTrainingJSONL() error: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
data, err := os.ReadFile(path)
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("failed to read output: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
if len(strings.TrimSpace(string(data))) != 0 {
|
||
|
|
t.Errorf("expected empty file, got %q", string(data))
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestWriteTrainingJSONLCreatesFile(t *testing.T) {
|
||
|
|
dir := t.TempDir()
|
||
|
|
subdir := filepath.Join(dir, "sub")
|
||
|
|
if err := os.MkdirAll(subdir, 0755); err != nil {
|
||
|
|
t.Fatalf("failed to create subdir: %v", err)
|
||
|
|
}
|
||
|
|
path := filepath.Join(subdir, "train.jsonl")
|
||
|
|
|
||
|
|
responses := []Response{
|
||
|
|
{ID: "1", Prompt: "hi", Response: "hello", Model: "m"},
|
||
|
|
}
|
||
|
|
|
||
|
|
if err := writeTrainingJSONL(path, responses); err != nil {
|
||
|
|
t.Fatalf("writeTrainingJSONL() error: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||
|
|
t.Error("expected file to be created")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestExportEndToEnd(t *testing.T) {
|
||
|
|
dir := t.TempDir()
|
||
|
|
inputPath := filepath.Join(dir, "golden.jsonl")
|
||
|
|
outputDir := filepath.Join(dir, "output")
|
||
|
|
|
||
|
|
// Create input with a mix of valid and invalid responses.
|
||
|
|
validResponse := strings.Repeat("This is a valid response with sufficient length. ", 3)
|
||
|
|
lines := []string{
|
||
|
|
mustJSON(t, Response{ID: "1", Prompt: "p1", Response: validResponse, Model: "m1", Domain: "d1"}),
|
||
|
|
mustJSON(t, Response{ID: "2", Prompt: "p2", Response: validResponse, Model: "m1", Domain: "d1"}),
|
||
|
|
mustJSON(t, Response{ID: "3", Prompt: "p3", Response: validResponse, Model: "m1", Domain: "d1"}),
|
||
|
|
mustJSON(t, Response{ID: "4", Prompt: "p4", Response: validResponse, Model: "m1", Domain: "d1"}),
|
||
|
|
mustJSON(t, Response{ID: "5", Prompt: "p5", Response: validResponse, Model: "m1", Domain: "d1"}),
|
||
|
|
mustJSON(t, Response{ID: "6", Prompt: "p6", Response: validResponse, Model: "m1", Domain: "d1"}),
|
||
|
|
mustJSON(t, Response{ID: "7", Prompt: "p7", Response: validResponse, Model: "m1", Domain: "d1"}),
|
||
|
|
mustJSON(t, Response{ID: "8", Prompt: "p8", Response: validResponse, Model: "m1", Domain: "d1"}),
|
||
|
|
mustJSON(t, Response{ID: "9", Prompt: "p9", Response: validResponse, Model: "m1", Domain: "d1"}),
|
||
|
|
mustJSON(t, Response{ID: "10", Prompt: "p10", Response: validResponse, Model: "m1", Domain: "d1"}),
|
||
|
|
// These should be filtered out.
|
||
|
|
mustJSON(t, Response{ID: "11", Prompt: "p11", Response: "", Model: "m1"}),
|
||
|
|
mustJSON(t, Response{ID: "12", Prompt: "p12", Response: "ERROR: timeout", Model: "m1"}),
|
||
|
|
mustJSON(t, Response{ID: "13", Prompt: "p13", Response: "short", Model: "m1"}),
|
||
|
|
}
|
||
|
|
|
||
|
|
if err := os.WriteFile(inputPath, []byte(strings.Join(lines, "\n")+"\n"), 0644); err != nil {
|
||
|
|
t.Fatalf("failed to write input: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Run export with 80/10/10 split.
|
||
|
|
args := []string{
|
||
|
|
"--input", inputPath,
|
||
|
|
"--output-dir", outputDir,
|
||
|
|
"--train-pct", "80",
|
||
|
|
"--valid-pct", "10",
|
||
|
|
"--test-pct", "10",
|
||
|
|
"--seed", "42",
|
||
|
|
}
|
||
|
|
runExport(args)
|
||
|
|
|
||
|
|
// Verify output files exist.
|
||
|
|
for _, name := range []string{"train.jsonl", "valid.jsonl", "test.jsonl"} {
|
||
|
|
path := filepath.Join(outputDir, name)
|
||
|
|
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||
|
|
t.Errorf("expected %s to exist", path)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Count lines in each file.
|
||
|
|
trainCount := countLines(t, filepath.Join(outputDir, "train.jsonl"))
|
||
|
|
validCount := countLines(t, filepath.Join(outputDir, "valid.jsonl"))
|
||
|
|
testCount := countLines(t, filepath.Join(outputDir, "test.jsonl"))
|
||
|
|
|
||
|
|
total := trainCount + validCount + testCount
|
||
|
|
if total != 10 {
|
||
|
|
t.Errorf("total exported = %d, want 10 (3 should be filtered)", total)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Train should be 80% of 10 = 8.
|
||
|
|
if trainCount != 8 {
|
||
|
|
t.Errorf("train count = %d, want 8", trainCount)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Valid should be 10% of 10 = 1.
|
||
|
|
if validCount != 1 {
|
||
|
|
t.Errorf("valid count = %d, want 1", validCount)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Test gets the remainder: 10 - 8 - 1 = 1.
|
||
|
|
if testCount != 1 {
|
||
|
|
t.Errorf("test count = %d, want 1", testCount)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Verify output format: each line should be a valid TrainingExample.
|
||
|
|
verifyTrainingFormat(t, filepath.Join(outputDir, "train.jsonl"))
|
||
|
|
verifyTrainingFormat(t, filepath.Join(outputDir, "valid.jsonl"))
|
||
|
|
verifyTrainingFormat(t, filepath.Join(outputDir, "test.jsonl"))
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestExportPercentageValidation(t *testing.T) {
|
||
|
|
tests := []struct {
|
||
|
|
name string
|
||
|
|
trainPct, validPct, testPct int
|
||
|
|
wantErr bool
|
||
|
|
}{
|
||
|
|
{"valid 90/5/5", 90, 5, 5, false},
|
||
|
|
{"valid 80/10/10", 80, 10, 10, false},
|
||
|
|
{"valid 100/0/0", 100, 0, 0, false},
|
||
|
|
{"invalid sum 90/10/10", 90, 10, 10, true},
|
||
|
|
{"invalid sum 50/50/50", 50, 50, 50, true},
|
||
|
|
{"invalid negative", -10, 60, 50, true},
|
||
|
|
{"invalid sum too low", 80, 5, 5, true},
|
||
|
|
}
|
||
|
|
|
||
|
|
for _, tt := range tests {
|
||
|
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
|
err := validatePercentages(tt.trainPct, tt.validPct, tt.testPct)
|
||
|
|
if tt.wantErr && err == nil {
|
||
|
|
t.Error("expected error, got nil")
|
||
|
|
}
|
||
|
|
if !tt.wantErr && err != nil {
|
||
|
|
t.Errorf("unexpected error: %v", err)
|
||
|
|
}
|
||
|
|
})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Helper functions.
|
||
|
|
|
||
|
|
// mustJSON returns v encoded as a JSON string. If encoding fails it aborts the
// calling test via t.Fatalf.
func mustJSON(t *testing.T, v interface{}) string {
	t.Helper()
	encoded, marshalErr := json.Marshal(v)
	if marshalErr == nil {
		return string(encoded)
	}
	t.Fatalf("failed to marshal: %v", marshalErr)
	return "" // not reached: Fatalf stops the test goroutine
}
|
||
|
|
|
||
|
|
// countLines returns the number of non-blank lines in the file at path. A line
// is blank when it is empty after trimming surrounding whitespace. The calling
// test is failed immediately if the file cannot be opened or read.
func countLines(t *testing.T, path string) int {
	t.Helper()
	f, err := os.Open(path)
	if err != nil {
		t.Fatalf("failed to open %s: %v", path, err)
	}
	defer f.Close()

	count := 0
	scanner := bufio.NewScanner(f)
	// bufio.Scanner rejects lines longer than 64 KiB by default; allow up to
	// 16 MiB so long JSONL records are still counted.
	scanner.Buffer(make([]byte, 0, 64*1024), 16*1024*1024)
	for scanner.Scan() {
		if strings.TrimSpace(scanner.Text()) != "" {
			count++
		}
	}
	// Scan returns false on errors as well as at EOF; report them rather than
	// returning a silently short count.
	if err := scanner.Err(); err != nil {
		t.Fatalf("failed to read %s: %v", path, err)
	}
	return count
}
|
||
|
|
|
||
|
|
func verifyTrainingFormat(t *testing.T, path string) {
|
||
|
|
t.Helper()
|
||
|
|
f, err := os.Open(path)
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("failed to open %s: %v", path, err)
|
||
|
|
}
|
||
|
|
defer f.Close()
|
||
|
|
|
||
|
|
scanner := bufio.NewScanner(f)
|
||
|
|
lineNum := 0
|
||
|
|
for scanner.Scan() {
|
||
|
|
lineNum++
|
||
|
|
line := strings.TrimSpace(scanner.Text())
|
||
|
|
if line == "" {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
|
||
|
|
var ex TrainingExample
|
||
|
|
if err := json.Unmarshal([]byte(line), &ex); err != nil {
|
||
|
|
t.Errorf("%s line %d: invalid JSON: %v", path, lineNum, err)
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
|
||
|
|
if len(ex.Messages) != 2 {
|
||
|
|
t.Errorf("%s line %d: expected 2 messages, got %d", path, lineNum, len(ex.Messages))
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
|
||
|
|
if ex.Messages[0].Role != "user" {
|
||
|
|
t.Errorf("%s line %d: messages[0].role = %q, want %q", path, lineNum, ex.Messages[0].Role, "user")
|
||
|
|
}
|
||
|
|
if ex.Messages[1].Role != "assistant" {
|
||
|
|
t.Errorf("%s line %d: messages[1].role = %q, want %q", path, lineNum, ex.Messages[1].Role, "assistant")
|
||
|
|
}
|
||
|
|
if ex.Messages[0].Content == "" {
|
||
|
|
t.Errorf("%s line %d: messages[0].content is empty", path, lineNum)
|
||
|
|
}
|
||
|
|
if ex.Messages[1].Content == "" {
|
||
|
|
t.Errorf("%s line %d: messages[1].content is empty", path, lineNum)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|