go-ml/adapter.go
Snider 747e703c7b feat: Phase 2 backend consolidation — Message alias, GenOpts, deprecation
- Replace Message struct with type alias for inference.Message (backward compat)
- Remove convertMessages() — types are now identical via alias
- Extend GenOpts with TopK, TopP, RepeatPenalty (mapped in convertOpts)
- Deprecate StreamingBackend with doc comment (only 2 callers, both in cli/)
- Simplify HTTPTextModel.Chat() — pass messages directly
- Update CLAUDE.md with Backend Architecture section
- Add 2 new tests, remove 1 obsolete test

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-20 02:05:59 +00:00

// SPDX-License-Identifier: EUPL-1.2
package ml

import (
	"context"
	"strings"

	"forge.lthn.ai/core/go-inference"
)

// InferenceAdapter bridges a go-inference TextModel (iter.Seq[Token]) to the
// ml.Backend and ml.StreamingBackend interfaces (string returns / TokenCallback).
//
// This is the key adapter for Phase 1: any go-inference backend (MLX Metal,
// ROCm, llama.cpp) can be wrapped to satisfy go-ml's Backend contract.
type InferenceAdapter struct {
	model inference.TextModel
	name  string
}

// Compile-time checks.
var _ Backend = (*InferenceAdapter)(nil)
var _ StreamingBackend = (*InferenceAdapter)(nil)

// NewInferenceAdapter wraps a go-inference TextModel as an ml.Backend and
// ml.StreamingBackend. The name is used for Backend.Name() (e.g. "mlx").
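//
// A minimal usage sketch (mlx.LoadModel is a hypothetical loader standing in
// for whichever concrete go-inference backend is in use; any constructor that
// returns an inference.TextModel works the same way):
//
//	model, err := mlx.LoadModel("llama-3.1-8b") // hypothetical
//	if err != nil {
//		return err
//	}
//	backend := ml.NewInferenceAdapter(model, "mlx")
//	defer backend.Close()
//
//	out, err := backend.Generate(ctx, "Say hello.", ml.GenOpts{MaxTokens: 64})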
func NewInferenceAdapter(model inference.TextModel, name string) *InferenceAdapter {
	return &InferenceAdapter{model: model, name: name}
}

// Generate collects all tokens from the model's iterator into a single string.
func (a *InferenceAdapter) Generate(ctx context.Context, prompt string, opts GenOpts) (string, error) {
	inferOpts := convertOpts(opts)
	var b strings.Builder
	for tok := range a.model.Generate(ctx, prompt, inferOpts...) {
		b.WriteString(tok.Text)
	}
	if err := a.model.Err(); err != nil {
		return b.String(), err
	}
	return b.String(), nil
}

// Chat sends a multi-turn conversation to the underlying TextModel and collects
// all tokens. Since ml.Message is now a type alias for inference.Message, no
// conversion is needed.
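//
// A sketch of a multi-turn call (assuming Message carries Role and Content
// fields, which the backward-compatible alias implies but this file does not
// show):
//
//	reply, err := backend.Chat(ctx, []ml.Message{
//		{Role: "system", Content: "You are terse."},
//		{Role: "user", Content: "Summarise the diff."},
//	}, ml.GenOpts{Temperature: 0.2})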
func (a *InferenceAdapter) Chat(ctx context.Context, messages []Message, opts GenOpts) (string, error) {
	inferOpts := convertOpts(opts)
	var b strings.Builder
	for tok := range a.model.Chat(ctx, messages, inferOpts...) {
		b.WriteString(tok.Text)
	}
	if err := a.model.Err(); err != nil {
		return b.String(), err
	}
	return b.String(), nil
}

// GenerateStream forwards each generated token's text to the callback.
// Returns nil on success, the callback's error if it stops early, or the
// model's error if generation fails.
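//
// A sketch of streaming to stdout and stopping early (errEnough is a
// hypothetical sentinel defined by the caller, not by this package):
//
//	var n int
//	err := backend.GenerateStream(ctx, prompt, ml.GenOpts{}, func(text string) error {
//		fmt.Print(text)
//		n += len(text)
//		if n > 4096 {
//			return errEnough // GenerateStream returns this to the caller
//		}
//		return nil
//	})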
func (a *InferenceAdapter) GenerateStream(ctx context.Context, prompt string, opts GenOpts, cb TokenCallback) error {
	inferOpts := convertOpts(opts)
	for tok := range a.model.Generate(ctx, prompt, inferOpts...) {
		if err := cb(tok.Text); err != nil {
			return err
		}
	}
	return a.model.Err()
}

// ChatStream forwards each generated chat token's text to the callback.
// Since ml.Message is now a type alias for inference.Message, no conversion
// is needed.
func (a *InferenceAdapter) ChatStream(ctx context.Context, messages []Message, opts GenOpts, cb TokenCallback) error {
	inferOpts := convertOpts(opts)
	for tok := range a.model.Chat(ctx, messages, inferOpts...) {
		if err := cb(tok.Text); err != nil {
			return err
		}
	}
	return a.model.Err()
}

// Name returns the backend identifier set at construction.
func (a *InferenceAdapter) Name() string { return a.name }

// Available always returns true — the model is already loaded.
func (a *InferenceAdapter) Available() bool { return true }

// Close delegates to the underlying TextModel.Close(), releasing GPU memory
// and other resources.
func (a *InferenceAdapter) Close() error { return a.model.Close() }

// Model returns the underlying go-inference TextModel for direct access
// to Classify, BatchGenerate, Metrics, Info, etc.
func (a *InferenceAdapter) Model() inference.TextModel { return a.model }

// convertOpts maps ml.GenOpts to go-inference functional options.
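//
// For example, the call below yields exactly three options, in declaration
// order; zero-valued fields are skipped so the backend's own defaults apply
// (the values are illustrative, not recommendations):
//
//	convertOpts(GenOpts{Temperature: 0.7, TopK: 40, MaxTokens: 256})
//	// -> WithTemperature(0.7), WithMaxTokens(256), WithTopK(40)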
func convertOpts(opts GenOpts) []inference.GenerateOption {
	var out []inference.GenerateOption
	if opts.Temperature != 0 {
		out = append(out, inference.WithTemperature(float32(opts.Temperature)))
	}
	if opts.MaxTokens != 0 {
		out = append(out, inference.WithMaxTokens(opts.MaxTokens))
	}
	if opts.TopK > 0 {
		out = append(out, inference.WithTopK(opts.TopK))
	}
	if opts.TopP > 0 {
		out = append(out, inference.WithTopP(float32(opts.TopP)))
	}
	if opts.RepeatPenalty > 0 {
		out = append(out, inference.WithRepeatPenalty(float32(opts.RepeatPenalty)))
	}
	// GenOpts.Model is ignored — the model is already loaded.
	return out
}