package inference

import (
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// --- GenerateConfig defaults ---

// TestDefaultGenerateConfig_Good checks every field of the value returned by
// DefaultGenerateConfig against the default expected for it.
func TestDefaultGenerateConfig_Good(t *testing.T) {
	def := DefaultGenerateConfig()

	assert.Equal(t, 256, def.MaxTokens, "default MaxTokens should be 256")
	assert.Equal(t, float32(0), def.Temperature, "default Temperature should be 0.0 (greedy)")
	assert.Equal(t, 0, def.TopK, "default TopK should be 0 (disabled)")
	assert.Equal(t, float32(0), def.TopP, "default TopP should be 0.0 (disabled)")
	assert.Nil(t, def.StopTokens, "default StopTokens should be nil")
	assert.Equal(t, float32(0), def.RepeatPenalty, "default RepeatPenalty should be 0.0 (disabled)")
	assert.False(t, def.ReturnLogits, "default ReturnLogits should be false")
}

// --- WithMaxTokens ---

// TestWithMaxTokens_Good verifies that WithMaxTokens stores the value it is
// given in MaxTokens, one subtest per table row.
func TestWithMaxTokens_Good(t *testing.T) {
	cases := []struct {
		name string
		val  int
		want int
	}{
		{name: "small", val: 32, want: 32},
		{name: "medium", val: 512, want: 512},
		{name: "large", val: 4096, want: 4096},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			opts := []GenerateOption{WithMaxTokens(tc.val)}
			got := ApplyGenerateOpts(opts)
			assert.Equal(t, tc.want, got.MaxTokens)
		})
	}
}

// TestWithMaxTokens_Bad pins the current behaviour for out-of-range input:
// the options layer performs no validation, so 0 and -1 are stored as-is.
func TestWithMaxTokens_Bad(t *testing.T) {
	for _, n := range []int{0, -1} {
		got := ApplyGenerateOpts([]GenerateOption{WithMaxTokens(n)})
		assert.Equal(t, n, got.MaxTokens)
	}
}

// --- WithTemperature ---

// TestWithTemperature_Good verifies that WithTemperature stores the value it
// is given in Temperature, compared within a small tolerance.
func TestWithTemperature_Good(t *testing.T) {
	cases := []struct {
		name string
		val  float32
		want float32
	}{
		{name: "greedy", val: 0.0, want: 0.0},
		{name: "low", val: 0.3, want: 0.3},
		{name: "default_creative", val: 0.7, want: 0.7},
		{name: "high", val: 1.5, want: 1.5},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			opts := []GenerateOption{WithTemperature(tc.val)}
			got := ApplyGenerateOpts(opts)
			assert.InDelta(t, tc.want, got.Temperature, 0.0001)
		})
	}
}

// --- WithTopK ---

// TestWithTopK_Good verifies that WithTopK stores the value it is given in
// TopK, including 0.
func TestWithTopK_Good(t *testing.T) {
	cases := []struct {
		name string
		val  int
		want int
	}{
		{name: "disabled", val: 0, want: 0},
		{name: "small", val: 10, want: 10},
		{name: "typical", val: 40, want: 40},
		{name: "large", val: 100, want: 100},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			opts := []GenerateOption{WithTopK(tc.val)}
			got := ApplyGenerateOpts(opts)
			assert.Equal(t, tc.want, got.TopK)
		})
	}
}

// --- WithTopP ---

// TestWithTopP_Good verifies that WithTopP stores the value it is given in
// TopP, compared within a small tolerance.
func TestWithTopP_Good(t *testing.T) {
	cases := []struct {
		name string
		val  float32
		want float32
	}{
		{name: "disabled", val: 0.0, want: 0.0},
		{name: "tight", val: 0.5, want: 0.5},
		{name: "typical", val: 0.9, want: 0.9},
		{name: "full", val: 1.0, want: 1.0},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			opts := []GenerateOption{WithTopP(tc.val)}
			got := ApplyGenerateOpts(opts)
			assert.InDelta(t, tc.want, got.TopP, 0.0001)
		})
	}
}

// --- WithStopTokens ---

// TestWithStopTokens_Good verifies that WithStopTokens stores the token IDs
// it is given, for both a single ID and several.
func TestWithStopTokens_Good(t *testing.T) {
	t.Run("single", func(t *testing.T) {
		got := ApplyGenerateOpts([]GenerateOption{WithStopTokens(1)})
		assert.Equal(t, []int32{1}, got.StopTokens)
	})

	t.Run("multiple", func(t *testing.T) {
		got := ApplyGenerateOpts([]GenerateOption{WithStopTokens(1, 2, 3)})
		assert.Equal(t, []int32{1, 2, 3}, got.StopTokens)
	})
}

// TestWithStopTokens_Ugly verifies that a second WithStopTokens replaces the
// tokens set by the first instead of appending to them.
func TestWithStopTokens_Ugly(t *testing.T) {
	opts := []GenerateOption{
		WithStopTokens(1, 2),
		WithStopTokens(3, 4, 5),
	}

	got := ApplyGenerateOpts(opts)
	assert.Equal(t, []int32{3, 4, 5}, got.StopTokens, "last WithStopTokens should win")
}

// --- WithRepeatPenalty ---

// TestWithRepeatPenalty_Good verifies that WithRepeatPenalty stores the value
// it is given in RepeatPenalty, compared within a small tolerance.
func TestWithRepeatPenalty_Good(t *testing.T) {
	cases := []struct {
		name string
		val  float32
		want float32
	}{
		{name: "disabled", val: 0.0, want: 0.0},
		{name: "no_penalty", val: 1.0, want: 1.0},
		{name: "typical", val: 1.1, want: 1.1},
		{name: "strong", val: 2.0, want: 2.0},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			opts := []GenerateOption{WithRepeatPenalty(tc.val)}
			got := ApplyGenerateOpts(opts)
			assert.InDelta(t, tc.want, got.RepeatPenalty, 0.0001)
		})
	}
}

// --- WithLogits ---

// TestWithLogits_Good verifies that WithLogits turns ReturnLogits on.
func TestWithLogits_Good(t *testing.T) {
	got := ApplyGenerateOpts([]GenerateOption{WithLogits()})
	assert.True(t, got.ReturnLogits)
}

// --- ApplyGenerateOpts ---

// TestApplyGenerateOpts_Good covers the no-option cases (nil and empty slice
// both yield DefaultGenerateConfig) and the case where every generate option
// is supplied at once.
func TestApplyGenerateOpts_Good(t *testing.T) {
	t.Run("nil_opts_returns_defaults", func(t *testing.T) {
		got := ApplyGenerateOpts(nil)
		want := DefaultGenerateConfig()
		assert.Equal(t, want, got)
	})

	t.Run("empty_opts_returns_defaults", func(t *testing.T) {
		got := ApplyGenerateOpts([]GenerateOption{})
		want := DefaultGenerateConfig()
		assert.Equal(t, want, got)
	})

	t.Run("all_options_combined", func(t *testing.T) {
		opts := []GenerateOption{
			WithMaxTokens(128),
			WithTemperature(0.7),
			WithTopK(40),
			WithTopP(0.9),
			WithStopTokens(1, 2),
			WithRepeatPenalty(1.1),
			WithLogits(),
		}

		got := ApplyGenerateOpts(opts)

		assert.Equal(t, 128, got.MaxTokens)
		assert.InDelta(t, 0.7, got.Temperature, 0.0001)
		assert.Equal(t, 40, got.TopK)
		assert.InDelta(t, 0.9, got.TopP, 0.0001)
		assert.Equal(t, []int32{1, 2}, got.StopTokens)
		assert.InDelta(t, 1.1, got.RepeatPenalty, 0.0001)
		assert.True(t, got.ReturnLogits)
	})
}

// TestApplyGenerateOpts_Ugly verifies that when the same option is passed
// more than once, the final occurrence determines the resulting value.
func TestApplyGenerateOpts_Ugly(t *testing.T) {
	t.Run("last_option_wins", func(t *testing.T) {
		opts := []GenerateOption{
			WithMaxTokens(100),
			WithMaxTokens(200),
			WithMaxTokens(300),
		}

		got := ApplyGenerateOpts(opts)
		assert.Equal(t, 300, got.MaxTokens, "last WithMaxTokens should win")
	})

	t.Run("temperature_override", func(t *testing.T) {
		opts := []GenerateOption{
			WithTemperature(0.5),
			WithTemperature(1.0),
		}

		got := ApplyGenerateOpts(opts)
		assert.InDelta(t, 1.0, got.Temperature, 0.0001, "last WithTemperature should win")
	})
}

// --- LoadConfig defaults ---

// TestApplyLoadOpts_Good_Defaults checks each field of the config produced by
// ApplyLoadOpts(nil) against the default expected for it.
func TestApplyLoadOpts_Good_Defaults(t *testing.T) {
	def := ApplyLoadOpts(nil)

	assert.Equal(t, "", def.Backend, "default Backend should be empty (auto-detect)")
	assert.Equal(t, 0, def.ContextLen, "default ContextLen should be 0 (model default)")
	assert.Equal(t, -1, def.GPULayers, "default GPULayers should be -1 (all layers)")
	assert.Equal(t, 0, def.ParallelSlots, "default ParallelSlots should be 0 (server default)")
}

// --- WithBackend ---

// TestWithBackend_Good verifies that WithBackend stores the backend name it
// is given.
func TestWithBackend_Good(t *testing.T) {
	cases := []struct {
		name    string
		backend string
	}{
		{name: "metal", backend: "metal"},
		{name: "rocm", backend: "rocm"},
		{name: "llama_cpp", backend: "llama_cpp"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			opts := []LoadOption{WithBackend(tc.backend)}
			got := ApplyLoadOpts(opts)
			assert.Equal(t, tc.backend, got.Backend)
		})
	}
}

// TestWithBackend_Bad pins that an empty backend name is accepted by the
// options layer and stored unchanged.
func TestWithBackend_Bad(t *testing.T) {
	got := ApplyLoadOpts([]LoadOption{WithBackend("")})
	assert.Equal(t, "", got.Backend)
}

// --- WithContextLen ---

// TestWithContextLen_Good verifies that WithContextLen stores the value it is
// given in ContextLen.
func TestWithContextLen_Good(t *testing.T) {
	cases := []struct {
		name string
		val  int
		want int
	}{
		{name: "small", val: 2048, want: 2048},
		{name: "medium", val: 4096, want: 4096},
		{name: "large", val: 32768, want: 32768},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			opts := []LoadOption{WithContextLen(tc.val)}
			got := ApplyLoadOpts(opts)
			assert.Equal(t, tc.want, got.ContextLen)
		})
	}
}

// --- WithGPULayers ---

// TestWithGPULayers_Good verifies that WithGPULayers stores the value it is
// given in GPULayers, including -1 and 0.
func TestWithGPULayers_Good(t *testing.T) {
	cases := []struct {
		name string
		val  int
		want int
	}{
		{name: "all", val: -1, want: -1},
		{name: "none", val: 0, want: 0},
		{name: "partial", val: 24, want: 24},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			opts := []LoadOption{WithGPULayers(tc.val)}
			got := ApplyLoadOpts(opts)
			assert.Equal(t, tc.want, got.GPULayers)
		})
	}
}

// TestWithGPULayers_Ugly verifies that an explicit WithGPULayers(0) replaces
// the -1 default rather than being treated as "unset".
func TestWithGPULayers_Ugly(t *testing.T) {
	got := ApplyLoadOpts([]LoadOption{WithGPULayers(0)})
	assert.Equal(t, 0, got.GPULayers, "WithGPULayers(0) should override default -1")
}

// --- WithParallelSlots ---

// TestWithParallelSlots_Good verifies that WithParallelSlots stores the value
// it is given in ParallelSlots.
func TestWithParallelSlots_Good(t *testing.T) {
	cases := []struct {
		name string
		val  int
		want int
	}{
		{name: "default", val: 0, want: 0},
		{name: "one", val: 1, want: 1},
		{name: "four", val: 4, want: 4},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			opts := []LoadOption{WithParallelSlots(tc.val)}
			got := ApplyLoadOpts(opts)
			assert.Equal(t, tc.want, got.ParallelSlots)
		})
	}
}

// --- ApplyLoadOpts combined ---

// TestApplyLoadOpts_Good_Combined verifies that all four load options can be
// supplied together and each lands in its own field.
func TestApplyLoadOpts_Good_Combined(t *testing.T) {
	opts := []LoadOption{
		WithBackend("rocm"),
		WithContextLen(8192),
		WithGPULayers(32),
		WithParallelSlots(2),
	}

	got := ApplyLoadOpts(opts)

	assert.Equal(t, "rocm", got.Backend)
	assert.Equal(t, 8192, got.ContextLen)
	assert.Equal(t, 32, got.GPULayers)
	assert.Equal(t, 2, got.ParallelSlots)
}

// TestApplyLoadOpts_Ugly covers a repeated option (the final one wins) and an
// empty, non-nil option slice (defaults are kept).
func TestApplyLoadOpts_Ugly(t *testing.T) {
	t.Run("last_option_wins", func(t *testing.T) {
		opts := []LoadOption{
			WithBackend("metal"),
			WithBackend("rocm"),
		}

		got := ApplyLoadOpts(opts)
		assert.Equal(t, "rocm", got.Backend, "last WithBackend should win")
	})

	t.Run("empty_slice_returns_defaults", func(t *testing.T) {
		got := ApplyLoadOpts([]LoadOption{})

		// require stops this subtest immediately if the GPULayers default is
		// wrong, so the Backend check below only runs when it holds.
		require.Equal(t, -1, got.GPULayers, "empty opts should keep default GPULayers=-1")
		assert.Equal(t, "", got.Backend)
	})
}