feat: Phase 2 backend consolidation — Message alias, GenOpts, deprecation
- Replace Message struct with type alias for inference.Message (backward compat)
- Remove convertMessages() — types are now identical via alias
- Extend GenOpts with TopK, TopP, RepeatPenalty (mapped in convertOpts)
- Deprecate StreamingBackend with doc comment (only 2 callers, both in cli/)
- Simplify HTTPTextModel.Chat() — pass messages directly
- Update CLAUDE.md with Backend Architecture section
- Add 2 new tests, remove 1 obsolete test

Co-Authored-By: Virgil <virgil@lethean.io>
parent fe70eb0ce1
commit 747e703c7b

5 changed files with 111 additions and 56 deletions
CLAUDE.md (41 changed lines)

@@ -94,10 +94,33 @@ All resolve via `replace` directives in go.mod:
 | `db.go` | 258 | DuckDB analytics storage |
 | `gguf.go` | 369 | GGUF model format parsing |

+### Backend Architecture
+
+Two interface families coexist, bridged by adapters:
+
+**`inference.TextModel`** (iterator-based) is the **preferred API** for new code. Returns `iter.Seq[inference.Token]` for streaming. Defined in `forge.lthn.ai/core/go-inference`. Use this for GPU backends (MLX Metal, ROCm) and any code that needs token-level control.
+
+**`ml.Backend`** (string-based) is the **compatibility layer**, still fully supported. Returns complete strings. Used by `service.go`, `judge.go`, and external consumers like `host-uk/cli`.
+
+**`ml.StreamingBackend`** is **deprecated**. New code should use `inference.TextModel` with `iter.Seq[Token]` directly. Retained for backward compatibility with existing callers.
+
+**Adapters:**
+
+| Adapter | Direction | File |
+|---------|-----------|------|
+| `InferenceAdapter` | `inference.TextModel` -> `ml.Backend` + `ml.StreamingBackend` | `adapter.go` |
+| `HTTPTextModel` | `ml.HTTPBackend` -> `inference.TextModel` | `backend_http_textmodel.go` |
+| `LlamaTextModel` | `ml.LlamaBackend` -> `inference.TextModel` | `backend_http_textmodel.go` |
+
+**Unified types (Phase 2):**
+
+- `ml.Message` is a type alias for `inference.Message` — the types are identical, no conversion needed between packages.
+- `ml.GenOpts` extends `inference.GenerateConfig` with a `Model` field for per-request model overrides. The `convertOpts()` helper maps GenOpts to `[]inference.GenerateOption`.
+
 ### Key Types

 ```go
-// Current backend interface (inference.go)
+// Backend interface (inference.go) — compatibility layer
 type Backend interface {
 	Generate(ctx context.Context, prompt string, opts GenOpts) (string, error)
 	Chat(ctx context.Context, messages []Message, opts GenOpts) (string, error)

@@ -105,6 +128,7 @@ type Backend interface {
 	Available() bool
 }

+// Deprecated: use inference.TextModel with iter.Seq[Token] directly
 type StreamingBackend interface {
 	Backend
 	GenerateStream(ctx context.Context, prompt string, opts GenOpts, cb TokenCallback) error

@@ -112,15 +136,16 @@ type StreamingBackend interface {
 }

 type GenOpts struct {
-	Temperature float64
-	MaxTokens   int
-	Model       string
+	Temperature   float64
+	MaxTokens     int
+	Model         string  // override model for this request
+	TopK          int     // top-k sampling (0 = disabled)
+	TopP          float64 // nucleus sampling threshold (0 = disabled)
+	RepeatPenalty float64 // repetition penalty (0 = disabled, 1.0 = no penalty)
 }

-type Message struct {
-	Role    string `json:"role"`
-	Content string `json:"content"`
-}
+// Type alias — identical to inference.Message
+type Message = inference.Message
 ```

 ## Coding Standards
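For orientation, a minimal sketch (not part of this commit) of how a caller consumes the iterator-based API described above. It assumes `inference.TextModel` exposes the `Chat(ctx, msgs, opts...)` and `Err()` pair exactly as `InferenceAdapter` uses them in `adapter.go` below, and that `inference.WithMaxTokens` exists as used in `convertOpts`; the helper name `collectChat` is hypothetical.

```go
package ml

import (
	"context"
	"strings"

	"forge.lthn.ai/core/go-inference"
)

// collectChat illustrates the preferred iterator-based path: range over the
// token stream, then check Err() once the stream ends.
func collectChat(ctx context.Context, model inference.TextModel, msgs []Message) (string, error) {
	var b strings.Builder
	// Message is an alias for inference.Message, so msgs passes through unchanged.
	for tok := range model.Chat(ctx, msgs, inference.WithMaxTokens(256)) {
		b.WriteString(tok.Text)
	}
	if err := model.Err(); err != nil {
		return "", err
	}
	return b.String(), nil
}
```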
adapter.go (30 changed lines)

@@ -42,12 +42,13 @@ func (a *InferenceAdapter) Generate(ctx context.Context, prompt string, opts Gen
 	return b.String(), nil
 }

-// Chat converts ml.Message to inference.Message, then collects all tokens.
+// Chat sends a multi-turn conversation to the underlying TextModel and collects
+// all tokens. Since ml.Message is now a type alias for inference.Message, no
+// conversion is needed.
 func (a *InferenceAdapter) Chat(ctx context.Context, messages []Message, opts GenOpts) (string, error) {
-	inferMsgs := convertMessages(messages)
 	inferOpts := convertOpts(opts)
 	var b strings.Builder
-	for tok := range a.model.Chat(ctx, inferMsgs, inferOpts...) {
+	for tok := range a.model.Chat(ctx, messages, inferOpts...) {
 		b.WriteString(tok.Text)
 	}
 	if err := a.model.Err(); err != nil {

@@ -70,10 +71,11 @@ func (a *InferenceAdapter) GenerateStream(ctx context.Context, prompt string, op
 }

 // ChatStream forwards each generated chat token's text to the callback.
+// Since ml.Message is now a type alias for inference.Message, no conversion
+// is needed.
 func (a *InferenceAdapter) ChatStream(ctx context.Context, messages []Message, opts GenOpts, cb TokenCallback) error {
-	inferMsgs := convertMessages(messages)
 	inferOpts := convertOpts(opts)
-	for tok := range a.model.Chat(ctx, inferMsgs, inferOpts...) {
+	for tok := range a.model.Chat(ctx, messages, inferOpts...) {
 		if err := cb(tok.Text); err != nil {
 			return err
 		}

@@ -104,15 +106,15 @@ func convertOpts(opts GenOpts) []inference.GenerateOption {
 	if opts.MaxTokens != 0 {
 		out = append(out, inference.WithMaxTokens(opts.MaxTokens))
 	}
+	if opts.TopK > 0 {
+		out = append(out, inference.WithTopK(opts.TopK))
+	}
+	if opts.TopP > 0 {
+		out = append(out, inference.WithTopP(float32(opts.TopP)))
+	}
+	if opts.RepeatPenalty > 0 {
+		out = append(out, inference.WithRepeatPenalty(float32(opts.RepeatPenalty)))
+	}
 	// GenOpts.Model is ignored — the model is already loaded.
 	return out
 }
-
-// convertMessages maps ml.Message to inference.Message (trivial field copy).
-func convertMessages(msgs []Message) []inference.Message {
-	out := make([]inference.Message, len(msgs))
-	for i, m := range msgs {
-		out[i] = inference.Message{Role: m.Role, Content: m.Content}
-	}
-	return out
-}
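As a usage illustration (not part of the diff), the compatibility path now takes the extra sampling knobs directly on `GenOpts`; `convertOpts` above forwards only the non-zero fields. The caller below is hypothetical and assumes an already-constructed `*InferenceAdapter`.

```go
package ml

import "context"

// askAdapter is a hypothetical caller of the string-based compatibility API.
func askAdapter(ctx context.Context, adapter *InferenceAdapter) (string, error) {
	msgs := []Message{
		{Role: "system", Content: "You are helpful."},
		{Role: "user", Content: "Hello"},
	}
	opts := GenOpts{
		Temperature:   0.7,
		MaxTokens:     512,
		TopK:          40,  // 0 leaves top-k disabled
		TopP:          0.9, // 0 leaves nucleus sampling disabled
		RepeatPenalty: 1.1, // 0 disables the penalty entirely
	}
	// convertOpts maps only the non-zero sampling fields to GenerateOption values.
	return adapter.Chat(ctx, msgs, opts)
}
```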
@@ -211,20 +211,49 @@ func TestInferenceAdapter_ConvertOpts_Good(t *testing.T) {
 	assert.Len(t, opts, 1)
 }

-func TestInferenceAdapter_ConvertMessages_Good(t *testing.T) {
-	mlMsgs := []Message{
-		{Role: "system", Content: "You are helpful."},
-		{Role: "user", Content: "Hello"},
-		{Role: "assistant", Content: "Hi!"},
-	}
-	inferMsgs := convertMessages(mlMsgs)
-	require.Len(t, inferMsgs, 3)
-	assert.Equal(t, "system", inferMsgs[0].Role)
-	assert.Equal(t, "You are helpful.", inferMsgs[0].Content)
-	assert.Equal(t, "user", inferMsgs[1].Role)
-	assert.Equal(t, "Hello", inferMsgs[1].Content)
-	assert.Equal(t, "assistant", inferMsgs[2].Role)
-	assert.Equal(t, "Hi!", inferMsgs[2].Content)
+func TestInferenceAdapter_ConvertOpts_NewFields_Good(t *testing.T) {
+	// TopK only.
+	opts := convertOpts(GenOpts{TopK: 40})
+	assert.Len(t, opts, 1)
+
+	// TopP only.
+	opts = convertOpts(GenOpts{TopP: 0.9})
+	assert.Len(t, opts, 1)
+
+	// RepeatPenalty only.
+	opts = convertOpts(GenOpts{RepeatPenalty: 1.1})
+	assert.Len(t, opts, 1)
+
+	// All new fields set together.
+	opts = convertOpts(GenOpts{TopK: 40, TopP: 0.9, RepeatPenalty: 1.1})
+	assert.Len(t, opts, 3)
+
+	// All fields set (Temperature + MaxTokens + TopK + TopP + RepeatPenalty).
+	opts = convertOpts(GenOpts{
+		Temperature:   0.7,
+		MaxTokens:     512,
+		TopK:          40,
+		TopP:          0.9,
+		RepeatPenalty: 1.1,
+	})
+	assert.Len(t, opts, 5)
+
+	// Zero TopK/TopP/RepeatPenalty should not produce options.
+	opts = convertOpts(GenOpts{Temperature: 0.5, TopK: 0, TopP: 0, RepeatPenalty: 0})
+	assert.Len(t, opts, 1) // only Temperature
 }

+func TestInferenceAdapter_MessageAlias_Good(t *testing.T) {
+	// ml.Message and inference.Message are the same type — verify interchangeability.
+	mlMsg := Message{Role: "user", Content: "Hello"}
+	inferMsg := inference.Message{Role: "user", Content: "Hello"}
+	assert.Equal(t, mlMsg, inferMsg)
+
+	// Can assign directly without conversion.
+	var msgs []inference.Message
+	msgs = append(msgs, mlMsg)
+	assert.Equal(t, "user", msgs[0].Role)
+	assert.Equal(t, "Hello", msgs[0].Content)
+}
+
 func TestInferenceAdapter_NameAndAvailable_Good(t *testing.T) {
backend_http_textmodel.go

@@ -60,13 +60,8 @@ func (m *HTTPTextModel) Chat(ctx context.Context, messages []inference.Message,
 		Model: m.http.Model(),
 	}

-	// Convert inference.Message to ml.Message.
-	mlMsgs := make([]Message, len(messages))
-	for i, msg := range messages {
-		mlMsgs[i] = Message{Role: msg.Role, Content: msg.Content}
-	}
-
-	result, err := m.http.Chat(ctx, mlMsgs, genOpts)
+	// ml.Message is now a type alias for inference.Message — no conversion needed.
+	result, err := m.http.Chat(ctx, messages, genOpts)
 	if err != nil {
 		m.lastErr = err
 		return
inference.go (30 changed lines)

@@ -11,7 +11,11 @@
 // )
 package ml

-import "context"
+import (
+	"context"
+
+	"forge.lthn.ai/core/go-inference"
+)

 // Backend generates text from prompts. Implementations include HTTPBackend
 // (OpenAI-compatible API), LlamaBackend (managed llama-server process), and

@@ -32,25 +36,25 @@ type Backend interface {

 // GenOpts configures a generation request.
 type GenOpts struct {
-	Temperature float64
-	MaxTokens   int
-	Model       string // override model for this request
+	Temperature   float64
+	MaxTokens     int
+	Model         string  // override model for this request
+	TopK          int     // top-k sampling (0 = disabled)
+	TopP          float64 // nucleus sampling threshold (0 = disabled)
+	RepeatPenalty float64 // repetition penalty (0 = disabled, 1.0 = no penalty)
 }

-// Message is a single chat message.
-type Message struct {
-	Role    string `json:"role"`
-	Content string `json:"content"`
-}
+// Message is a type alias for inference.Message, providing backward compatibility.
+// All callers continue using ml.Message — it is the same underlying type.
+type Message = inference.Message

 // TokenCallback receives each generated token as text. Return a non-nil
 // error to stop generation early (e.g. client disconnect).
 type TokenCallback func(token string) error

-// StreamingBackend extends Backend with token-by-token streaming.
-// Backends that generate tokens incrementally (e.g. MLX) should implement
-// this interface. The serve handler uses SSE when the client sends
-// "stream": true and the active backend satisfies StreamingBackend.
+// Deprecated: StreamingBackend is retained for backward compatibility.
+// New code should use inference.TextModel with iter.Seq[Token] directly.
+// See InferenceAdapter for the bridge pattern.
 type StreamingBackend interface {
 	Backend
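For completeness, a minimal sketch of the callback-style streaming that `StreamingBackend` retains for its existing callers; new code should prefer the iterator path shown earlier. It is not part of this commit, assumes a `StreamingBackend` obtained elsewhere (e.g. an `InferenceAdapter`), and takes `GenerateStream`'s signature from the Key Types section above.

```go
package ml

import (
	"context"
	"strings"
)

// collectStream is a hypothetical caller of the deprecated callback API.
// Returning a non-nil error from the callback stops generation early.
func collectStream(ctx context.Context, sb StreamingBackend, prompt string) (string, error) {
	var b strings.Builder
	cb := func(token string) error {
		b.WriteString(token)
		return nil // return an error here to abort the stream (e.g. client disconnect)
	}
	err := sb.GenerateStream(ctx, prompt, GenOpts{MaxTokens: 128}, cb)
	return b.String(), err
}
```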