diff --git a/TODO.md b/TODO.md index 6b3afaa..e441652 100644 --- a/TODO.md +++ b/TODO.md @@ -10,11 +10,11 @@ Everything downstream is blocked on this. The old `backend_mlx.go` imports go-ml ### Step 1.1: Add go-inference dependency -- [ ] **Add `forge.lthn.ai/core/go-inference` to go.mod** — Already has a `replace` directive pointing to `../go-inference`. Run `go get forge.lthn.ai/core/go-inference` then `go mod tidy`. Verify the module resolves. +- [x] **Add `forge.lthn.ai/core/go-inference` to go.mod** — Already has a `replace` directive pointing to `../go-inference`. Run `go get forge.lthn.ai/core/go-inference` then `go mod tidy`. Verify the module resolves. ### Step 1.2: Write the InferenceAdapter -- [ ] **Create `adapter.go`** — Bridge between `go-inference.TextModel` (returns `iter.Seq[Token]`) and `ml.Backend` + `ml.StreamingBackend` (returns `string`/callback). Must implement: +- [x] **Create `adapter.go`** — Bridge between `go-inference.TextModel` (returns `iter.Seq[Token]`) and `ml.Backend` + `ml.StreamingBackend` (returns `string`/callback). Must implement: - `Generate()` — collect tokens from iterator into string - `Chat()` — same, using `TextModel.Chat()` - `GenerateStream()` — forward tokens to `TokenCallback` @@ -32,7 +32,7 @@ Everything downstream is blocked on this. The old `backend_mlx.go` imports go-ml **Error handling**: After the iterator completes, check `model.Err()` to distinguish EOS from errors (OOM, ctx cancelled). -- [ ] **Test adapter.go** — Test with a mock `inference.TextModel` that yields predetermined tokens. Test cases: +- [x] **Test adapter.go** — 13 test cases with mock TextModel (all pass). Test cases: - Normal generation (collect tokens → string) - Streaming (each token hits callback) - Callback error stops iteration @@ -42,7 +42,7 @@ Everything downstream is blocked on this. 
The old `backend_mlx.go` imports go-ml ### Step 1.3: Rewrite backend_mlx.go -- [ ] **Replace backend_mlx.go** — Delete the 253 LOC that manually handle tokenisation, KV cache, sampling, and memory cleanup. Replace with ~60 LOC: +- [x] **Replace backend_mlx.go** — Deleted the 253 LOC that manually handle tokenisation, KV cache, sampling, and memory cleanup. Replaced with ~35 LOC: ```go //go:build darwin && arm64 @@ -63,7 +63,7 @@ Everything downstream is blocked on this. The old `backend_mlx.go` imports go-ml ``` The `InferenceAdapter` from Step 1.2 handles all the Generate/Chat/Stream logic. -- [ ] **Preserve memory controls** — The old `MLXBackend` set cache/memory limits (16GB/24GB). These should be configurable. Options: +- [ ] **Preserve memory controls** — The old `MLXBackend` set cache/memory limits (16GB/24GB). Now delegated to go-mlx internally. Callers can still use `mlx.SetCacheLimit()`/`mlx.SetMemoryLimit()` directly. Options for future: - Accept memory limits in `NewMLXBackend` params - Or set them in `InferenceAdapter` wrapper - go-mlx exposes `SetCacheLimit()` / `SetMemoryLimit()` at package level diff --git a/adapter.go b/adapter.go new file mode 100644 index 0000000..b3ca8fb --- /dev/null +++ b/adapter.go @@ -0,0 +1,118 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package ml + +import ( + "context" + "strings" + + "forge.lthn.ai/core/go-inference" +) + +// InferenceAdapter bridges a go-inference TextModel (iter.Seq[Token]) to the +// ml.Backend and ml.StreamingBackend interfaces (string returns / TokenCallback). +// +// This is the key adapter for Phase 1: any go-inference backend (MLX Metal, +// ROCm, llama.cpp) can be wrapped to satisfy go-ml's Backend contract. +type InferenceAdapter struct { + model inference.TextModel + name string +} + +// Compile-time checks. 
+var _ Backend = (*InferenceAdapter)(nil) +var _ StreamingBackend = (*InferenceAdapter)(nil) + +// NewInferenceAdapter wraps a go-inference TextModel as an ml.Backend and +// ml.StreamingBackend. The name is used for Backend.Name() (e.g. "mlx"). +func NewInferenceAdapter(model inference.TextModel, name string) *InferenceAdapter { + return &InferenceAdapter{model: model, name: name} +} + +// Generate collects all tokens from the model's iterator into a single string. +func (a *InferenceAdapter) Generate(ctx context.Context, prompt string, opts GenOpts) (string, error) { + inferOpts := convertOpts(opts) + var b strings.Builder + for tok := range a.model.Generate(ctx, prompt, inferOpts...) { + b.WriteString(tok.Text) + } + if err := a.model.Err(); err != nil { + return b.String(), err + } + return b.String(), nil +} + +// Chat converts ml.Message to inference.Message, then collects all tokens. +func (a *InferenceAdapter) Chat(ctx context.Context, messages []Message, opts GenOpts) (string, error) { + inferMsgs := convertMessages(messages) + inferOpts := convertOpts(opts) + var b strings.Builder + for tok := range a.model.Chat(ctx, inferMsgs, inferOpts...) { + b.WriteString(tok.Text) + } + if err := a.model.Err(); err != nil { + return b.String(), err + } + return b.String(), nil +} + +// GenerateStream forwards each generated token's text to the callback. +// Returns nil on success, the callback's error if it stops early, or the +// model's error if generation fails. +func (a *InferenceAdapter) GenerateStream(ctx context.Context, prompt string, opts GenOpts, cb TokenCallback) error { + inferOpts := convertOpts(opts) + for tok := range a.model.Generate(ctx, prompt, inferOpts...) { + if err := cb(tok.Text); err != nil { + return err + } + } + return a.model.Err() +} + +// ChatStream forwards each generated chat token's text to the callback. 
+func (a *InferenceAdapter) ChatStream(ctx context.Context, messages []Message, opts GenOpts, cb TokenCallback) error { + inferMsgs := convertMessages(messages) + inferOpts := convertOpts(opts) + for tok := range a.model.Chat(ctx, inferMsgs, inferOpts...) { + if err := cb(tok.Text); err != nil { + return err + } + } + return a.model.Err() +} + +// Name returns the backend identifier set at construction. +func (a *InferenceAdapter) Name() string { return a.name } + +// Available always returns true — the model is already loaded. +func (a *InferenceAdapter) Available() bool { return true } + +// Close delegates to the underlying TextModel.Close(), releasing GPU memory +// and other resources. +func (a *InferenceAdapter) Close() error { return a.model.Close() } + +// Model returns the underlying go-inference TextModel for direct access +// to Classify, BatchGenerate, Metrics, Info, etc. +func (a *InferenceAdapter) Model() inference.TextModel { return a.model } + +// convertOpts maps ml.GenOpts to go-inference functional options. +func convertOpts(opts GenOpts) []inference.GenerateOption { + var out []inference.GenerateOption + if opts.Temperature != 0 { + out = append(out, inference.WithTemperature(float32(opts.Temperature))) + } + if opts.MaxTokens != 0 { + out = append(out, inference.WithMaxTokens(opts.MaxTokens)) + } + // GenOpts.Model is ignored — the model is already loaded. + return out +} + +// convertMessages maps ml.Message to inference.Message (trivial field copy). 
+func convertMessages(msgs []Message) []inference.Message { + out := make([]inference.Message, len(msgs)) + for i, m := range msgs { + out[i] = inference.Message{Role: m.Role, Content: m.Content} + } + return out +} diff --git a/adapter_test.go b/adapter_test.go new file mode 100644 index 0000000..8a934fc --- /dev/null +++ b/adapter_test.go @@ -0,0 +1,252 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package ml + +import ( + "context" + "errors" + "iter" + "testing" + + "forge.lthn.ai/core/go-inference" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockTextModel implements inference.TextModel for testing the InferenceAdapter. +type mockTextModel struct { + tokens []inference.Token // tokens to yield + err error // error to return from Err() + closed bool + modelType string +} + +func (m *mockTextModel) Generate(_ context.Context, _ string, _ ...inference.GenerateOption) iter.Seq[inference.Token] { + return func(yield func(inference.Token) bool) { + for _, tok := range m.tokens { + if !yield(tok) { + return + } + } + } +} + +func (m *mockTextModel) Chat(_ context.Context, _ []inference.Message, _ ...inference.GenerateOption) iter.Seq[inference.Token] { + return func(yield func(inference.Token) bool) { + for _, tok := range m.tokens { + if !yield(tok) { + return + } + } + } +} + +func (m *mockTextModel) Classify(_ context.Context, _ []string, _ ...inference.GenerateOption) ([]inference.ClassifyResult, error) { + panic("Classify not used by adapter") +} + +func (m *mockTextModel) BatchGenerate(_ context.Context, _ []string, _ ...inference.GenerateOption) ([]inference.BatchResult, error) { + panic("BatchGenerate not used by adapter") +} + +func (m *mockTextModel) ModelType() string { return m.modelType } +func (m *mockTextModel) Info() inference.ModelInfo { return inference.ModelInfo{} } +func (m *mockTextModel) Metrics() inference.GenerateMetrics { return inference.GenerateMetrics{} } +func (m *mockTextModel) Err() error { 
return m.err } +func (m *mockTextModel) Close() error { m.closed = true; return nil } + +// --- Tests --- + +func TestInferenceAdapter_Generate_Good(t *testing.T) { + mock := &mockTextModel{ + tokens: []inference.Token{ + {ID: 1, Text: "Hello"}, + {ID: 2, Text: " "}, + {ID: 3, Text: "world"}, + }, + } + adapter := NewInferenceAdapter(mock, "test") + + result, err := adapter.Generate(context.Background(), "prompt", GenOpts{}) + require.NoError(t, err) + assert.Equal(t, "Hello world", result) +} + +func TestInferenceAdapter_Generate_Empty_Good(t *testing.T) { + mock := &mockTextModel{tokens: nil} + adapter := NewInferenceAdapter(mock, "test") + + result, err := adapter.Generate(context.Background(), "prompt", GenOpts{}) + require.NoError(t, err) + assert.Equal(t, "", result) +} + +func TestInferenceAdapter_Generate_ModelError_Bad(t *testing.T) { + mock := &mockTextModel{ + tokens: []inference.Token{ + {ID: 1, Text: "partial"}, + }, + err: errors.New("out of memory"), + } + adapter := NewInferenceAdapter(mock, "test") + + result, err := adapter.Generate(context.Background(), "prompt", GenOpts{}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "out of memory") + // Partial output is still returned. 
+ assert.Equal(t, "partial", result) +} + +func TestInferenceAdapter_GenerateStream_Good(t *testing.T) { + mock := &mockTextModel{ + tokens: []inference.Token{ + {ID: 1, Text: "one"}, + {ID: 2, Text: "two"}, + {ID: 3, Text: "three"}, + }, + } + adapter := NewInferenceAdapter(mock, "test") + + var collected []string + err := adapter.GenerateStream(context.Background(), "prompt", GenOpts{}, func(token string) error { + collected = append(collected, token) + return nil + }) + require.NoError(t, err) + assert.Equal(t, []string{"one", "two", "three"}, collected) +} + +func TestInferenceAdapter_GenerateStream_CallbackError_Bad(t *testing.T) { + callbackErr := errors.New("client disconnected") + mock := &mockTextModel{ + tokens: []inference.Token{ + {ID: 1, Text: "one"}, + {ID: 2, Text: "two"}, + {ID: 3, Text: "three"}, + }, + } + adapter := NewInferenceAdapter(mock, "test") + + count := 0 + err := adapter.GenerateStream(context.Background(), "prompt", GenOpts{}, func(token string) error { + count++ + if count >= 2 { + return callbackErr + } + return nil + }) + assert.ErrorIs(t, err, callbackErr) + assert.Equal(t, 2, count, "callback should have been called exactly twice") +} + +func TestInferenceAdapter_ContextCancellation_Bad(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + // Create a mock that respects context cancellation. + mock := &mockTextModel{} + mock.tokens = nil // no tokens; the mock Generate just returns empty + // Simulate context cancel causing model error. 
+ cancel() + mock.err = ctx.Err() + + adapter := NewInferenceAdapter(mock, "test") + _, err := adapter.Generate(ctx, "prompt", GenOpts{}) + assert.ErrorIs(t, err, context.Canceled) +} + +func TestInferenceAdapter_Chat_Good(t *testing.T) { + mock := &mockTextModel{ + tokens: []inference.Token{ + {ID: 1, Text: "Hi"}, + {ID: 2, Text: " there"}, + }, + } + adapter := NewInferenceAdapter(mock, "test") + + messages := []Message{ + {Role: "user", Content: "Hello"}, + {Role: "assistant", Content: "Hi"}, + {Role: "user", Content: "How are you?"}, + } + result, err := adapter.Chat(context.Background(), messages, GenOpts{}) + require.NoError(t, err) + assert.Equal(t, "Hi there", result) +} + +func TestInferenceAdapter_ChatStream_Good(t *testing.T) { + mock := &mockTextModel{ + tokens: []inference.Token{ + {ID: 1, Text: "reply"}, + {ID: 2, Text: "!"}, + }, + } + adapter := NewInferenceAdapter(mock, "test") + + messages := []Message{{Role: "user", Content: "test"}} + var collected []string + err := adapter.ChatStream(context.Background(), messages, GenOpts{}, func(token string) error { + collected = append(collected, token) + return nil + }) + require.NoError(t, err) + assert.Equal(t, []string{"reply", "!"}, collected) +} + +func TestInferenceAdapter_ConvertOpts_Good(t *testing.T) { + // Non-zero values should produce options. + opts := convertOpts(GenOpts{Temperature: 0.7, MaxTokens: 512, Model: "ignored"}) + assert.Len(t, opts, 2) + + // Zero values should produce no options. + opts = convertOpts(GenOpts{}) + assert.Len(t, opts, 0) + + // Only temperature set. + opts = convertOpts(GenOpts{Temperature: 0.5}) + assert.Len(t, opts, 1) + + // Only max tokens set. 
+ opts = convertOpts(GenOpts{MaxTokens: 100}) + assert.Len(t, opts, 1) +} + +func TestInferenceAdapter_ConvertMessages_Good(t *testing.T) { + mlMsgs := []Message{ + {Role: "system", Content: "You are helpful."}, + {Role: "user", Content: "Hello"}, + {Role: "assistant", Content: "Hi!"}, + } + inferMsgs := convertMessages(mlMsgs) + require.Len(t, inferMsgs, 3) + assert.Equal(t, "system", inferMsgs[0].Role) + assert.Equal(t, "You are helpful.", inferMsgs[0].Content) + assert.Equal(t, "user", inferMsgs[1].Role) + assert.Equal(t, "Hello", inferMsgs[1].Content) + assert.Equal(t, "assistant", inferMsgs[2].Role) + assert.Equal(t, "Hi!", inferMsgs[2].Content) +} + +func TestInferenceAdapter_NameAndAvailable_Good(t *testing.T) { + mock := &mockTextModel{} + adapter := NewInferenceAdapter(mock, "mlx") + + assert.Equal(t, "mlx", adapter.Name()) + assert.True(t, adapter.Available()) +} + +func TestInferenceAdapter_Close_Good(t *testing.T) { + mock := &mockTextModel{} + adapter := NewInferenceAdapter(mock, "test") + + err := adapter.Close() + require.NoError(t, err) + assert.True(t, mock.closed) +} + +func TestInferenceAdapter_Model_Good(t *testing.T) { + mock := &mockTextModel{modelType: "qwen3"} + adapter := NewInferenceAdapter(mock, "test") + + assert.Equal(t, "qwen3", adapter.Model().ModelType()) +} diff --git a/backend_mlx.go b/backend_mlx.go index d7596a0..afe2e00 100644 --- a/backend_mlx.go +++ b/backend_mlx.go @@ -1,253 +1,38 @@ +// SPDX-License-Identifier: EUPL-1.2 + //go:build darwin && arm64 package ml import ( - "context" "fmt" "log/slog" - "runtime" - "strings" - "sync" - "forge.lthn.ai/core/go-mlx" - "forge.lthn.ai/core/go-mlx/cache" - "forge.lthn.ai/core/go-mlx/model" - "forge.lthn.ai/core/go-mlx/sample" - "forge.lthn.ai/core/go-mlx/tokenizer" + "forge.lthn.ai/core/go-inference" + _ "forge.lthn.ai/core/go-mlx" // registers "metal" backend via init() ) -// MLXBackend implements Backend and StreamingBackend for native Metal inference. 
-type MLXBackend struct { - model model.Model - tok *tokenizer.Tokenizer - caches []cache.Cache - sampler sample.Sampler - mu sync.Mutex - modelBytes uint64 // model size at load time, for memory budget -} +// NewMLXBackend loads a model via go-inference's Metal backend and wraps it +// in an InferenceAdapter for use as ml.Backend/StreamingBackend. +// +// The blank import of go-mlx registers the "metal" backend, so +// inference.LoadModel() will automatically use Metal on Apple Silicon. +// +// Load options (context length, etc.) are forwarded directly to go-inference. +func NewMLXBackend(modelPath string, loadOpts ...inference.LoadOption) (*InferenceAdapter, error) { + slog.Info("mlx: loading model via go-inference", "path", modelPath) -// Compile-time check that MLXBackend satisfies StreamingBackend. -var _ StreamingBackend = (*MLXBackend)(nil) - -// NewMLXBackend loads a model from a safetensors directory and creates -// a native Metal inference backend. -func NewMLXBackend(modelPath string) (*MLXBackend, error) { - if !mlx.MetalAvailable() { - return nil, fmt.Errorf("mlx: Metal GPU not available") - } - - slog.Info("mlx: loading model", "path", modelPath) - m, err := model.LoadModel(modelPath) + m, err := inference.LoadModel(modelPath, loadOpts...) if err != nil { - return nil, fmt.Errorf("mlx: load model: %w", err) + return nil, fmt.Errorf("mlx: %w", err) } - // Cap Metal memory: cache limit for allocator reuse, memory limit as hard ceiling. 
- mlx.SetCacheLimit(16 * 1024 * 1024 * 1024) // 16 GB allocator cache - mlx.SetMemoryLimit(24 * 1024 * 1024 * 1024) // 24 GB hard cap - - modelMB := mlx.GetActiveMemory() / 1024 / 1024 + info := m.Info() slog.Info("mlx: model loaded", - "layers", m.NumLayers(), - "memory_mb", modelMB, + "arch", info.Architecture, + "layers", info.NumLayers, + "quant", info.QuantBits, ) - return &MLXBackend{ - model: m, - tok: m.Tokenizer(), - caches: m.NewCache(), - sampler: sample.New(0.1, 0, 0, 0), - modelBytes: mlx.GetActiveMemory(), - }, nil -} - -// generate is the core token generation loop. If cb is non-nil, each token's -// text is sent to it (streaming mode). Returns the full output text. -func (b *MLXBackend) generate(ctx context.Context, tokens []int32, opts GenOpts, cb TokenCallback) (string, error) { - b.mu.Lock() - defer b.mu.Unlock() - - for _, c := range b.caches { - c.Reset() - } - - temp := float32(opts.Temperature) - if temp == 0 { - temp = 0.1 - } - sampler := sample.New(temp, 0, 0, 0) - - input := mlx.FromValues(tokens, 1, len(tokens)) - - maxTokens := opts.MaxTokens - if maxTokens == 0 { - maxTokens = 2048 - } - - var output []int32 - firstToken := true - for i := 0; i < maxTokens; i++ { - select { - case <-ctx.Done(): - runtime.GC() - mlx.ClearCache() - return b.tok.Decode(output), ctx.Err() - default: - } - - logits := b.model.Forward(input, b.caches) - logits = lastPosition(logits) - next := sampler.Sample(logits) - mlx.Materialize(next) - - nextToken := int32(next.Int()) - if nextToken == b.tok.EOSToken() { - break - } - output = append(output, nextToken) - input = mlx.FromValues([]int32{nextToken}, 1, 1) - - // Stream the token text to the callback - if cb != nil { - tokenText := b.tok.DecodeToken(nextToken) - // Strip the SentencePiece leading space only on the first token - if firstToken { - tokenText = strings.TrimLeft(tokenText, " ") - firstToken = false - } - if err := cb(tokenText); err != nil { - runtime.GC() - mlx.ClearCache() - return 
b.tok.Decode(output), err - } - } - - if i%4 == 3 { - runtime.GC() - mlx.ClearCache() - } - } - - runtime.GC() - mlx.ClearCache() - b.checkMemory() - return b.tok.Decode(output), nil -} - -// Generate produces text from a prompt using native Metal inference. -func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts) (string, error) { - formatted := formatPrompt(b.model.ModelType(), prompt) - tokens := b.tok.Encode(formatted) - return b.generate(ctx, tokens, opts, nil) -} - -// Chat formats messages and generates a response. -func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts) (string, error) { - prompt := formatChat(b.model.ModelType(), messages) - tokens := b.tok.Encode(prompt) - return b.generate(ctx, tokens, opts, nil) -} - -// GenerateStream streams tokens from a single prompt via the callback. -func (b *MLXBackend) GenerateStream(ctx context.Context, prompt string, opts GenOpts, cb TokenCallback) error { - formatted := formatPrompt(b.model.ModelType(), prompt) - tokens := b.tok.Encode(formatted) - _, err := b.generate(ctx, tokens, opts, cb) - return err -} - -// ChatStream streams tokens from a chat conversation via the callback. -func (b *MLXBackend) ChatStream(ctx context.Context, messages []Message, opts GenOpts, cb TokenCallback) error { - prompt := formatChat(b.model.ModelType(), messages) - tokens := b.tok.Encode(prompt) - _, err := b.generate(ctx, tokens, opts, cb) - return err -} - -// lastPosition extracts the last sequence position from [B, L, V] logits → [B, V]. 
-func lastPosition(logits *mlx.Array) *mlx.Array { - shape := logits.Shape() - if len(shape) == 3 && shape[1] > 1 { - L := shape[1] - logits = mlx.Slice(logits, []int32{0, L - 1, 0}, []int32{shape[0], L, shape[2]}) - logits = mlx.Reshape(logits, shape[0], shape[2]) - } else if len(shape) == 3 && shape[1] == 1 { - logits = mlx.Reshape(logits, shape[0], shape[2]) - } - return logits -} - -// checkMemory logs Metal memory usage and forces cleanup if it exceeds budget. -func (b *MLXBackend) checkMemory() { - active := mlx.GetActiveMemory() - budget := b.modelBytes * 3 - if active > budget { - slog.Warn("mlx: memory over budget, forcing cleanup", - "active_mb", active/1024/1024, - "model_mb", b.modelBytes/1024/1024, - "peak_mb", mlx.GetPeakMemory()/1024/1024, - ) - runtime.GC() - runtime.GC() - mlx.ClearCache() - } -} - -// Name returns the backend identifier. -func (b *MLXBackend) Name() string { return "mlx" } - -// Available reports whether Metal GPU is ready. -func (b *MLXBackend) Available() bool { return mlx.MetalAvailable() } - -// formatPrompt wraps a raw prompt in the model's chat template for single-turn generation. -func formatPrompt(modelType, prompt string) string { - switch modelType { - case "qwen3": - return fmt.Sprintf("<|im_start|>user\n%s<|im_end|>\n<|im_start|>assistant\n", prompt) - default: - return tokenizer.FormatGemmaPrompt(prompt) - } -} - -// formatChat builds a multi-turn chat prompt from messages using the model's template. 
-func formatChat(modelType string, messages []Message) string { - switch modelType { - case "qwen3": - return formatQwen3Chat(messages) - default: - return formatGemmaChat(messages) - } -} - -func formatGemmaChat(messages []Message) string { - var prompt string - for _, msg := range messages { - switch msg.Role { - case "user": - prompt += fmt.Sprintf("user\n%s\n", msg.Content) - case "assistant": - prompt += fmt.Sprintf("model\n%s\n", msg.Content) - case "system": - prompt += fmt.Sprintf("user\n[System: %s]\n", msg.Content) - } - } - prompt += "model\n" - return prompt -} - -func formatQwen3Chat(messages []Message) string { - var prompt string - for _, msg := range messages { - switch msg.Role { - case "system": - prompt += fmt.Sprintf("<|im_start|>system\n%s<|im_end|>\n", msg.Content) - case "user": - prompt += fmt.Sprintf("<|im_start|>user\n%s<|im_end|>\n", msg.Content) - case "assistant": - prompt += fmt.Sprintf("<|im_start|>assistant\n%s<|im_end|>\n", msg.Content) - } - } - prompt += "<|im_start|>assistant\n" - return prompt + return NewInferenceAdapter(m, "mlx"), nil } diff --git a/go.mod b/go.mod index b306255..cffb57b 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.25.5 require ( forge.lthn.ai/core/go v0.0.0 + forge.lthn.ai/core/go-inference v0.0.0 forge.lthn.ai/core/go-mlx v0.0.0 github.com/marcboeker/go-duckdb v1.8.5 github.com/parquet-go/parquet-go v0.27.0 @@ -14,6 +15,7 @@ require ( github.com/andybalholm/brotli v1.2.0 // indirect github.com/apache/arrow-go/v18 v18.5.1 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/go-viper/mapstructure/v2 v2.5.0 // indirect github.com/goccy/go-json v0.10.5 // indirect github.com/google/flatbuffers v25.12.19+incompatible // indirect github.com/google/uuid v1.6.0 // indirect @@ -27,8 +29,13 @@ require ( github.com/zeebo/xxh3 v1.1.0 // indirect golang.org/x/exp v0.0.0-20260212183809-81e46e3db34a // indirect golang.org/x/mod v0.33.0 // indirect + golang.org/x/sync 
v0.19.0 // indirect golang.org/x/sys v0.41.0 // indirect + golang.org/x/telemetry v0.0.0-20260209163413-e7419c687ee4 // indirect + golang.org/x/tools v0.42.0 // indirect golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect + google.golang.org/protobuf v1.36.11 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) replace forge.lthn.ai/core/go => ../host-uk/core diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..3885fd6 --- /dev/null +++ b/go.sum @@ -0,0 +1,89 @@ +github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= +github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= +github.com/alecthomas/assert/v2 v2.10.0 h1:jjRCHsj6hBJhkmhznrCzoNpbA3zqy0fYiUcYZP/GkPY= +github.com/alecthomas/assert/v2 v2.10.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k= +github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc= +github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= +github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= +github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= +github.com/apache/arrow-go/v18 v18.5.1 h1:yaQ6zxMGgf9YCYw4/oaeOU3AULySDlAYDOcnr4LdHdI= +github.com/apache/arrow-go/v18 v18.5.1/go.mod h1:OCCJsmdq8AsRm8FkBSSmYTwL/s4zHW9CqxeBxEytkNE= +github.com/apache/thrift v0.22.0 h1:r7mTJdj51TMDe6RtcmNdQxgn9XcyfGDOzegMDRg47uc= +github.com/apache/thrift v0.22.0/go.mod h1:1e7J/O1Ae6ZQMTYdy9xa3w9k+XHWPfRvdPyJeynQ+/g= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-viper/mapstructure/v2 v2.5.0 h1:vM5IJoUAy3d7zRSVtIwQgBj7BiWtMPfmPEgAXnvj1Ro= +github.com/go-viper/mapstructure/v2 v2.5.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= 
+github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= +github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs= +github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/flatbuffers v25.12.19+incompatible h1:haMV2JRRJCe1998HeW/p0X9UaMTK6SDo0ffLn2+DbLs= +github.com/google/flatbuffers v25.12.19+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= +github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= +github.com/klauspost/asmfmt v1.3.2 h1:4Ri7ox3EwapiOjCki+hw14RyKk201CN4rzyCJRFLpK4= +github.com/klauspost/asmfmt v1.3.2/go.mod h1:AG8TuvYojzulgDAMCnYn50l/5QV3Bs/tp6j0HLHbNSE= +github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c= +github.com/klauspost/compress v1.18.4/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= +github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= +github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/marcboeker/go-duckdb v1.8.5 
h1:tkYp+TANippy0DaIOP5OEfBEwbUINqiFqgwMQ44jME0= +github.com/marcboeker/go-duckdb v1.8.5/go.mod h1:6mK7+WQE4P4u5AFLvVBmhFxY5fvhymFptghgJX6B+/8= +github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs= +github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8/go.mod h1:mC1jAcsrzbxHt8iiaC+zU4b1ylILSosueou12R++wfY= +github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 h1:+n/aFZefKZp7spd8DFdX7uMikMLXX4oubIzJF4kv/wI= +github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3/go.mod h1:RagcQ7I8IeTMnF8JTXieKnO4Z6JCsikNEzj0DwauVzE= +github.com/parquet-go/bitpack v1.0.0 h1:AUqzlKzPPXf2bCdjfj4sTeacrUwsT7NlcYDMUQxPcQA= +github.com/parquet-go/bitpack v1.0.0/go.mod h1:XnVk9TH+O40eOOmvpAVZ7K2ocQFrQwysLMnc6M/8lgs= +github.com/parquet-go/jsonlite v1.4.0 h1:RTG7prqfO0HD5egejU8MUDBN8oToMj55cgSV1I0zNW4= +github.com/parquet-go/jsonlite v1.4.0/go.mod h1:nDjpkpL4EOtqs6NQugUsi0Rleq9sW/OtC1NnZEnxzF0= +github.com/parquet-go/parquet-go v0.27.0 h1:vHWK2xaHbj+v1DYps03yDRpEsdtOeKbhiXUaixoPb3g= +github.com/parquet-go/parquet-go v0.27.0/go.mod h1:navtkAYr2LGoJVp141oXPlO/sxLvaOe3la2JEoD8+rg= +github.com/pierrec/lz4/v4 v4.1.25 h1:kocOqRffaIbU5djlIBr7Wh+cx82C0vtFb0fOurZHqD0= +github.com/pierrec/lz4/v4 v4.1.25/go.mod h1:EoQMVJgeeEOMsCqCzqFm2O0cJvljX2nGZjcRIPL34O4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/twpayne/go-geom v1.6.1 
h1:iLE+Opv0Ihm/ABIcvQFGIiFBXd76oBIar9drAwHFhR4= +github.com/twpayne/go-geom v1.6.1/go.mod h1:Kr+Nly6BswFsKM5sd31YaoWS5PeDDH2NftJTK7Gd028= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ= +github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= +github.com/zeebo/xxh3 v1.1.0 h1:s7DLGDK45Dyfg7++yxI0khrfwq9661w9EN78eP/UZVs= +github.com/zeebo/xxh3 v1.1.0/go.mod h1:IisAie1LELR4xhVinxWS5+zf1lA4p0MW4T+w+W07F5s= +golang.org/x/exp v0.0.0-20260212183809-81e46e3db34a h1:ovFr6Z0MNmU7nH8VaX5xqw+05ST2uO1exVfZPVqRC5o= +golang.org/x/exp v0.0.0-20260212183809-81e46e3db34a/go.mod h1:K79w1Vqn7PoiZn+TkNpx3BUWUQksGO3JcVX6qIjytmA= +golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= +golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/telemetry v0.0.0-20260209163413-e7419c687ee4 h1:bTLqdHv7xrGlFbvf5/TXNxy/iUwwdkjhqQTJDjW7aj0= +golang.org/x/telemetry v0.0.0-20260209163413-e7419c687ee4/go.mod h1:g5NllXBEermZrmR51cJDQxmJUHUOfRAaNyWBM+R+548= +golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= +golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= +golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da h1:noIWHXmPHxILtqtCOPIhSt0ABwskkZKjD3bXGnZGpNY= +golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= +gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4= 
+gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=