// SPDX-License-Identifier: EUPL-1.2

package ml

import (
	"context"
	"fmt"
	"strings"

	inference "forge.lthn.ai/core/go-inference"
)

// InferenceAdapter bridges a go-inference TextModel (iter.Seq[Token]) to the
// ml.Backend and ml.StreamingBackend interfaces (string returns / TokenCallback).
//
// This is the key adapter for Phase 1: any go-inference backend (MLX Metal,
// ROCm, llama.cpp) can be wrapped to satisfy go-ml's Backend contract.
type InferenceAdapter struct {
	model inference.TextModel
	name  string
}

// Compile-time checks.
var _ Backend = (*InferenceAdapter)(nil)
var _ StreamingBackend = (*InferenceAdapter)(nil)

// NewInferenceAdapter wraps a go-inference TextModel as an ml.Backend and
// ml.StreamingBackend. The name is used for Backend.Name() (e.g. "mlx").
func NewInferenceAdapter(model inference.TextModel, name string) *InferenceAdapter {
	return &InferenceAdapter{model: model, name: name}
}

// Generate collects all tokens from the model's iterator into a single string.
func (a *InferenceAdapter) Generate(ctx context.Context, prompt string, opts GenOpts) (Result, error) {
	inferOpts := convertOpts(opts)
	var b strings.Builder
	for tok := range a.model.Generate(ctx, prompt, inferOpts...) {
		b.WriteString(tok.Text)
	}
	if err := a.model.Err(); err != nil {
		return Result{Text: b.String()}, err
	}
	return Result{Text: b.String(), Metrics: metricsPtr(a.model)}, nil
}

// Chat sends a multi-turn conversation to the underlying TextModel and collects
// all tokens. Since ml.Message is now a type alias for inference.Message, no
// conversion is needed.
func (a *InferenceAdapter) Chat(ctx context.Context, messages []Message, opts GenOpts) (Result, error) {
	inferOpts := convertOpts(opts)
	var b strings.Builder
	for tok := range a.model.Chat(ctx, messages, inferOpts...) {
		b.WriteString(tok.Text)
	}
	if err := a.model.Err(); err != nil {
		return Result{Text: b.String()}, err
	}
	return Result{Text: b.String(), Metrics: metricsPtr(a.model)}, nil
}

// GenerateStream forwards each generated token's text to the callback.
// It returns nil on success, the callback's error if the callback stops the
// stream early, or the model's error if generation fails.
func (a *InferenceAdapter) GenerateStream(ctx context.Context, prompt string, opts GenOpts, cb TokenCallback) error {
	inferOpts := convertOpts(opts)
	for tok := range a.model.Generate(ctx, prompt, inferOpts...) {
		if err := cb(tok.Text); err != nil {
			return err
		}
	}
	return a.model.Err()
}

// ChatStream forwards each generated chat token's text to the callback.
// Since ml.Message is now a type alias for inference.Message, no conversion
// is needed.
func (a *InferenceAdapter) ChatStream(ctx context.Context, messages []Message, opts GenOpts, cb TokenCallback) error {
	inferOpts := convertOpts(opts)
	for tok := range a.model.Chat(ctx, messages, inferOpts...) {
		if err := cb(tok.Text); err != nil {
			return err
		}
	}
	return a.model.Err()
}

// Name returns the backend identifier set at construction.
func (a *InferenceAdapter) Name() string { return a.name }

// Available always returns true; the wrapped model is already loaded.
func (a *InferenceAdapter) Available() bool { return true }

// Close delegates to the underlying TextModel.Close(), releasing GPU memory
// and other resources.
func (a *InferenceAdapter) Close() error { return a.model.Close() }
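
// exampleGenerate is a minimal usage sketch, not part of the adapter API. It
// shows the intended blocking call pattern: wrap an already-loaded TextModel,
// generate, and read Result.Text. The concrete model is assumed to come from
// whichever go-inference backend constructor is in use (MLX, ROCm, llama.cpp);
// no specific loader is assumed here.
func exampleGenerate(ctx context.Context, model inference.TextModel) (string, error) {
	backend := NewInferenceAdapter(model, "mlx")
	defer backend.Close()

	// Zero-valued GenOpts fields are omitted by convertOpts, so backend
	// defaults apply to anything not set here.
	res, err := backend.Generate(ctx, "Summarise iter.Seq in one line.", GenOpts{
		Temperature: 0.7,
		MaxTokens:   64,
	})
	if err != nil {
		return "", err
	}
	return res.Text, nil
}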
// Model returns the underlying go-inference TextModel for direct access to
// Classify, BatchGenerate, Metrics, Info, etc.
func (a *InferenceAdapter) Model() inference.TextModel { return a.model }

// InspectAttention delegates to the underlying TextModel if it implements
// inference.AttentionInspector. It returns an error if the backend does not
// support attention inspection.
func (a *InferenceAdapter) InspectAttention(ctx context.Context, prompt string, opts ...inference.GenerateOption) (*inference.AttentionSnapshot, error) {
	inspector, ok := a.model.(inference.AttentionInspector)
	if !ok {
		return nil, fmt.Errorf("backend %q does not support attention inspection", a.name)
	}
	return inspector.InspectAttention(ctx, prompt, opts...)
}

// convertOpts maps ml.GenOpts to go-inference functional options. Zero values
// are treated as unset and omitted, so the backend's defaults apply.
func convertOpts(opts GenOpts) []inference.GenerateOption {
	var out []inference.GenerateOption
	if opts.Temperature != 0 {
		out = append(out, inference.WithTemperature(float32(opts.Temperature)))
	}
	if opts.MaxTokens != 0 {
		out = append(out, inference.WithMaxTokens(opts.MaxTokens))
	}
	if opts.TopK > 0 {
		out = append(out, inference.WithTopK(opts.TopK))
	}
	if opts.TopP > 0 {
		out = append(out, inference.WithTopP(float32(opts.TopP)))
	}
	if opts.RepeatPenalty > 0 {
		out = append(out, inference.WithRepeatPenalty(float32(opts.RepeatPenalty)))
	}
	// GenOpts.Model is ignored; the model is already loaded.
	return out
}

// metricsPtr returns a pointer to a copy of the model's latest metrics. Note
// that it never returns nil: Metrics() yields a value, and a copy of it is
// always taken.
func metricsPtr(m inference.TextModel) *inference.GenerateMetrics {
	met := m.Metrics()
	return &met
}
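
// exampleStream is a usage sketch, not part of the adapter API. It shows how
// GenerateStream delivers each token's text through a TokenCallback, and that
// returning a non-nil error from the callback stops generation early. As in
// exampleGenerate above, the caller is assumed to supply an already-loaded
// TextModel from some go-inference backend.
func exampleStream(ctx context.Context, model inference.TextModel) (string, error) {
	backend := NewInferenceAdapter(model, "mlx")
	defer backend.Close()

	var out strings.Builder
	err := backend.GenerateStream(ctx, "Stream me a haiku.", GenOpts{MaxTokens: 48},
		func(text string) error {
			out.WriteString(text)
			return nil // return a non-nil error here to cancel the stream
		})
	// err is either the callback's early-stop error or the model's own
	// generation error, mirroring GenerateStream's documented contract.
	return out.String(), err
}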