diff --git a/inference_test.go b/inference_test.go index cefd2aa..31c2168 100644 --- a/inference_test.go +++ b/inference_test.go @@ -29,7 +29,7 @@ type stubBackend struct { loadErr error } -func (s *stubBackend) Name() string { return s.name } +func (s *stubBackend) Name() string { return s.name } func (s *stubBackend) Available() bool { return s.available } func (s *stubBackend) LoadModel(path string, opts ...LoadOption) (TextModel, error) { if s.loadErr != nil { @@ -45,8 +45,8 @@ type capturingBackend struct { capturedOpts []LoadOption } -func (c *capturingBackend) Name() string { return c.name } -func (c *capturingBackend) Available() bool { return c.available } +func (c *capturingBackend) Name() string { return c.name } +func (c *capturingBackend) Available() bool { return c.available } func (c *capturingBackend) LoadModel(path string, opts ...LoadOption) (TextModel, error) { c.capturedOpts = opts return &stubTextModel{backend: c.name, path: path}, nil @@ -71,10 +71,10 @@ func (m *stubTextModel) BatchGenerate(_ context.Context, _ []string, _ ...Genera return nil, nil } func (m *stubTextModel) ModelType() string { return "stub" } -func (m *stubTextModel) Info() ModelInfo { return ModelInfo{} } -func (m *stubTextModel) Metrics() GenerateMetrics { return GenerateMetrics{} } -func (m *stubTextModel) Err() error { return nil } -func (m *stubTextModel) Close() error { return nil } +func (m *stubTextModel) Info() ModelInfo { return ModelInfo{} } +func (m *stubTextModel) Metrics() GenerateMetrics { return GenerateMetrics{} } +func (m *stubTextModel) Err() error { return nil } +func (m *stubTextModel) Close() error { return nil } // --- Register --- @@ -493,27 +493,21 @@ func TestRegistry_Good_ConcurrentAccess(t *testing.T) { // Concurrent readers interleaved with writers. for range 20 { - wg.Add(1) - go func() { - defer wg.Done() + wg.Go(func() { _ = List() - }() + }) } for range 20 { - wg.Add(1) - go func() { - defer wg.Done() + wg.Go(func() { _, _ = Get("backend_0") - }() + }) } for range 10 { - wg.Add(1) - go func() { - defer wg.Done() + wg.Go(func() { _, _ = Default() - }() + }) } wg.Wait() diff --git a/options.go b/options.go index c5f48e9..db8f24e 100644 --- a/options.go +++ b/options.go @@ -72,6 +72,7 @@ type LoadConfig struct { ContextLen int // Context window size (0 = model default) GPULayers int // Number of layers to offload to GPU (-1 = all, 0 = none) ParallelSlots int // Number of concurrent inference slots (0 = server default) + AdapterPath string // Path to LoRA adapter directory (empty = no adapter) } // LoadOption configures model loading. @@ -100,6 +101,14 @@ func WithParallelSlots(n int) LoadOption { return func(c *LoadConfig) { c.ParallelSlots = n } } +// WithAdapterPath sets the path to a LoRA adapter directory. +// The directory should contain adapter_config.json and adapter safetensors files. +// The adapter weights are loaded and injected into the model at load time, +// enabling inference with a fine-tuned adapter without fusing/merging first. +func WithAdapterPath(path string) LoadOption { + return func(c *LoadConfig) { c.AdapterPath = path } +} + // ApplyLoadOpts builds a LoadConfig from options. func ApplyLoadOpts(opts []LoadOption) LoadConfig { cfg := LoadConfig{ diff --git a/options_test.go b/options_test.go index 63522c4..3f1e313 100644 --- a/options_test.go +++ b/options_test.go @@ -467,4 +467,51 @@ func TestApplyLoadOpts_Ugly(t *testing.T) { }) assert.Equal(t, 1, cfg.ParallelSlots, "last WithParallelSlots should win") }) + + t.Run("adapter_path_override", func(t *testing.T) { + cfg := ApplyLoadOpts([]LoadOption{ + WithAdapterPath("/path/a"), + WithAdapterPath("/path/b"), + }) + assert.Equal(t, "/path/b", cfg.AdapterPath, "last WithAdapterPath should win") + }) +} + +// --- WithAdapterPath --- + +func TestWithAdapterPath_Good(t *testing.T) { + tests := []struct { + name string + val string + want string + }{ + {"simple", "/path/to/adapter", "/path/to/adapter"}, + {"relative", "adapters/lora-v1", "adapters/lora-v1"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := ApplyLoadOpts([]LoadOption{WithAdapterPath(tt.val)}) + assert.Equal(t, tt.want, cfg.AdapterPath) + }) + } +} + +func TestWithAdapterPath_Bad(t *testing.T) { + // Empty string is valid at the options layer (means no adapter). + cfg := ApplyLoadOpts([]LoadOption{WithAdapterPath("")}) + assert.Equal(t, "", cfg.AdapterPath) +} + +func TestWithAdapterPath_Good_DefaultIsEmpty(t *testing.T) { + cfg := ApplyLoadOpts(nil) + assert.Equal(t, "", cfg.AdapterPath, "default AdapterPath should be empty") +} + +func TestWithAdapterPath_Good_OtherFieldsUnchanged(t *testing.T) { + cfg := ApplyLoadOpts([]LoadOption{WithAdapterPath("/some/path")}) + assert.Equal(t, "", cfg.Backend, "Backend should remain at default") + assert.Equal(t, 0, cfg.ContextLen, "ContextLen should remain at default") + assert.Equal(t, -1, cfg.GPULayers, "GPULayers should remain at default") + assert.Equal(t, 0, cfg.ParallelSlots, "ParallelSlots should remain at default") + assert.Equal(t, "/some/path", cfg.AdapterPath) }