feat: add lem worker command for distributed inference network
Go client for the LEM distributed inference API (BugSETI/Agentic). Workers register via Forgejo PAT auth, pull prompt batches, run local inference (MLX/vLLM/llama.cpp), submit results. Credits tracked as Phase 1 stub for Phase 2 blockchain LEM token. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
774f097855
commit
08363ee1af
3 changed files with 656 additions and 0 deletions
5
main.go
5
main.go
|
|
@ -42,6 +42,9 @@ Monitoring:
|
|||
coverage Analyze seed coverage gaps
|
||||
metrics Push DuckDB golden set stats to InfluxDB
|
||||
|
||||
Distributed:
|
||||
worker Run as distributed inference worker node
|
||||
|
||||
Infrastructure:
|
||||
ingest Ingest benchmark data into InfluxDB
|
||||
seed-influx Seed InfluxDB golden_gen from DuckDB
|
||||
|
|
@ -101,6 +104,8 @@ func main() {
|
|||
lem.RunQuery(os.Args[2:])
|
||||
case "agent":
|
||||
lem.RunAgent(os.Args[2:])
|
||||
case "worker":
|
||||
lem.RunWorker(os.Args[2:])
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "unknown command: %s\n\n%s", os.Args[1], usage)
|
||||
os.Exit(1)
|
||||
|
|
|
|||
454
pkg/lem/worker.go
Normal file
454
pkg/lem/worker.go
Normal file
|
|
@ -0,0 +1,454 @@
|
|||
package lem
|
||||
|
||||
import (
	"bytes"
	"encoding/json"
	"flag"
	"fmt"
	"io"
	"log"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"runtime"
	"time"
)
|
||||
|
||||
// workerConfig carries everything a worker needs at runtime. RunWorker fills
// it from command-line flags and the helpers below read from it.
type workerConfig struct {
	// Identity and credentials used when talking to the LEM API.
	apiBase  string // base URL prefixed to every API path
	workerID string // value sent as "worker_id" in API calls
	name     string // display name sent at registration
	apiKey   string // bearer token for the Authorization header

	// Capabilities advertised at registration; empty/zero values are omitted.
	gpuType   string
	vramGb    int
	languages []string
	models    []string

	// Local inference endpoint (an OpenAI-style /v1/chat/completions server).
	inferURL string

	// Polling behaviour.
	taskType     string        // optional task-type filter for the task query
	batchSize    int           // "limit" sent when fetching tasks
	pollInterval time.Duration // sleep between polls when nothing was processed
	oneShot      bool          // process a single batch and return
	dryRun       bool          // claim tasks but skip inference and submission
}
|
||||
|
||||
// apiTask mirrors one task object as returned by the LEM API task endpoints.
// Field names map to the snake_case JSON keys given in the struct tags.
type apiTask struct {
	ID       int    `json:"id"`
	TaskType string `json:"task_type"`
	Status   string `json:"status"`

	// What to generate and with which model.
	Language   string `json:"language"`
	Domain     string `json:"domain"`
	ModelName  string `json:"model_name"`
	PromptID   string `json:"prompt_id"`
	PromptText string `json:"prompt_text"`

	// Config holds optional per-task sampling overrides; it is nil when the
	// API sends no (or a null) "config" object.
	Config *struct {
		Temperature float64 `json:"temperature,omitempty"`
		MaxTokens   int     `json:"max_tokens,omitempty"`
	} `json:"config"`

	Priority int `json:"priority"`
}
|
||||
|
||||
// RunWorker is the CLI entry point for `lem worker`.
|
||||
func RunWorker(args []string) {
|
||||
fs := flag.NewFlagSet("worker", flag.ExitOnError)
|
||||
|
||||
cfg := workerConfig{}
|
||||
var langs, mods string
|
||||
|
||||
fs.StringVar(&cfg.apiBase, "api", envOr("LEM_API", "https://infer.lthn.ai"), "LEM API base URL")
|
||||
fs.StringVar(&cfg.workerID, "id", envOr("LEM_WORKER_ID", machineID()), "Worker ID (machine UUID)")
|
||||
fs.StringVar(&cfg.name, "name", envOr("LEM_WORKER_NAME", hostname()), "Worker display name")
|
||||
fs.StringVar(&cfg.apiKey, "key", envOr("LEM_API_KEY", ""), "API key (or use LEM_API_KEY env)")
|
||||
fs.StringVar(&cfg.gpuType, "gpu", envOr("LEM_GPU", ""), "GPU type (e.g. 'RTX 3090')")
|
||||
fs.IntVar(&cfg.vramGb, "vram", intEnvOr("LEM_VRAM_GB", 0), "GPU VRAM in GB")
|
||||
fs.StringVar(&langs, "languages", envOr("LEM_LANGUAGES", ""), "Comma-separated language codes (e.g. 'en,yo,sw')")
|
||||
fs.StringVar(&mods, "models", envOr("LEM_MODELS", ""), "Comma-separated supported model names")
|
||||
fs.StringVar(&cfg.inferURL, "infer", envOr("LEM_INFER_URL", "http://localhost:8090"), "Local inference endpoint")
|
||||
fs.StringVar(&cfg.taskType, "type", "", "Filter by task type (expand, score, translate, seed)")
|
||||
fs.IntVar(&cfg.batchSize, "batch", 5, "Number of tasks to fetch per poll")
|
||||
dur := fs.Duration("poll", 30*time.Second, "Poll interval")
|
||||
fs.BoolVar(&cfg.oneShot, "one-shot", false, "Process one batch and exit")
|
||||
fs.BoolVar(&cfg.dryRun, "dry-run", false, "Fetch tasks but don't run inference")
|
||||
|
||||
fs.Parse(args)
|
||||
cfg.pollInterval = *dur
|
||||
|
||||
if langs != "" {
|
||||
cfg.languages = splitComma(langs)
|
||||
}
|
||||
if mods != "" {
|
||||
cfg.models = splitComma(mods)
|
||||
}
|
||||
|
||||
if cfg.apiKey == "" {
|
||||
cfg.apiKey = readKeyFile()
|
||||
}
|
||||
|
||||
if cfg.apiKey == "" {
|
||||
log.Fatal("No API key. Set LEM_API_KEY or run: lem worker-auth")
|
||||
}
|
||||
|
||||
log.Printf("LEM Worker starting")
|
||||
log.Printf(" ID: %s", cfg.workerID)
|
||||
log.Printf(" Name: %s", cfg.name)
|
||||
log.Printf(" API: %s", cfg.apiBase)
|
||||
log.Printf(" Infer: %s", cfg.inferURL)
|
||||
log.Printf(" GPU: %s (%d GB)", cfg.gpuType, cfg.vramGb)
|
||||
log.Printf(" Langs: %v", cfg.languages)
|
||||
log.Printf(" Models: %v", cfg.models)
|
||||
log.Printf(" Batch: %d", cfg.batchSize)
|
||||
log.Printf(" Dry-run: %v", cfg.dryRun)
|
||||
|
||||
// Register with the API.
|
||||
if err := workerRegister(&cfg); err != nil {
|
||||
log.Fatalf("Registration failed: %v", err)
|
||||
}
|
||||
log.Println("Registered with LEM API")
|
||||
|
||||
// Main loop.
|
||||
for {
|
||||
processed := workerPoll(&cfg)
|
||||
|
||||
if cfg.oneShot {
|
||||
log.Printf("One-shot mode: processed %d tasks, exiting", processed)
|
||||
return
|
||||
}
|
||||
|
||||
if processed == 0 {
|
||||
log.Printf("No tasks available, sleeping %v", cfg.pollInterval)
|
||||
time.Sleep(cfg.pollInterval)
|
||||
}
|
||||
|
||||
// Heartbeat.
|
||||
workerHeartbeat(&cfg)
|
||||
}
|
||||
}
|
||||
|
||||
func workerRegister(cfg *workerConfig) error {
|
||||
body := map[string]interface{}{
|
||||
"worker_id": cfg.workerID,
|
||||
"name": cfg.name,
|
||||
"version": "0.1.0",
|
||||
"os": runtime.GOOS,
|
||||
"arch": runtime.GOARCH,
|
||||
}
|
||||
if cfg.gpuType != "" {
|
||||
body["gpu_type"] = cfg.gpuType
|
||||
}
|
||||
if cfg.vramGb > 0 {
|
||||
body["vram_gb"] = cfg.vramGb
|
||||
}
|
||||
if len(cfg.languages) > 0 {
|
||||
body["languages"] = cfg.languages
|
||||
}
|
||||
if len(cfg.models) > 0 {
|
||||
body["supported_models"] = cfg.models
|
||||
}
|
||||
|
||||
_, err := apiPost(cfg, "/api/lem/workers/register", body)
|
||||
return err
|
||||
}
|
||||
|
||||
func workerHeartbeat(cfg *workerConfig) {
|
||||
body := map[string]interface{}{
|
||||
"worker_id": cfg.workerID,
|
||||
}
|
||||
apiPost(cfg, "/api/lem/workers/heartbeat", body)
|
||||
}
|
||||
|
||||
func workerPoll(cfg *workerConfig) int {
|
||||
// Fetch available tasks.
|
||||
url := fmt.Sprintf("/api/lem/tasks/next?worker_id=%s&limit=%d", cfg.workerID, cfg.batchSize)
|
||||
if cfg.taskType != "" {
|
||||
url += "&type=" + cfg.taskType
|
||||
}
|
||||
|
||||
resp, err := apiGet(cfg, url)
|
||||
if err != nil {
|
||||
log.Printf("Error fetching tasks: %v", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Tasks []apiTask `json:"tasks"`
|
||||
Count int `json:"count"`
|
||||
}
|
||||
if err := json.Unmarshal(resp, &result); err != nil {
|
||||
log.Printf("Error parsing tasks: %v", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
if result.Count == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
log.Printf("Got %d tasks", result.Count)
|
||||
processed := 0
|
||||
|
||||
for _, task := range result.Tasks {
|
||||
if err := workerProcessTask(cfg, task); err != nil {
|
||||
log.Printf("Task %d failed: %v", task.ID, err)
|
||||
// Release the claim so someone else can try.
|
||||
apiDelete(cfg, fmt.Sprintf("/api/lem/tasks/%d/claim", task.ID), map[string]interface{}{
|
||||
"worker_id": cfg.workerID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
processed++
|
||||
}
|
||||
|
||||
return processed
|
||||
}
|
||||
|
||||
func workerProcessTask(cfg *workerConfig, task apiTask) error {
|
||||
log.Printf("Processing task %d: %s [%s/%s] %d chars prompt",
|
||||
task.ID, task.TaskType, task.Language, task.Domain, len(task.PromptText))
|
||||
|
||||
// Claim the task.
|
||||
_, err := apiPost(cfg, fmt.Sprintf("/api/lem/tasks/%d/claim", task.ID), map[string]interface{}{
|
||||
"worker_id": cfg.workerID,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("claim: %w", err)
|
||||
}
|
||||
|
||||
// Update to in_progress.
|
||||
apiPatch(cfg, fmt.Sprintf("/api/lem/tasks/%d/status", task.ID), map[string]interface{}{
|
||||
"worker_id": cfg.workerID,
|
||||
"status": "in_progress",
|
||||
})
|
||||
|
||||
if cfg.dryRun {
|
||||
log.Printf(" [DRY-RUN] Would generate response for: %.80s...", task.PromptText)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Run inference via local API.
|
||||
start := time.Now()
|
||||
response, err := workerInfer(cfg, task)
|
||||
genTime := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
// Report failure, release task.
|
||||
apiPatch(cfg, fmt.Sprintf("/api/lem/tasks/%d/status", task.ID), map[string]interface{}{
|
||||
"worker_id": cfg.workerID,
|
||||
"status": "abandoned",
|
||||
})
|
||||
return fmt.Errorf("inference: %w", err)
|
||||
}
|
||||
|
||||
// Submit result.
|
||||
modelUsed := task.ModelName
|
||||
if modelUsed == "" {
|
||||
modelUsed = "default"
|
||||
}
|
||||
|
||||
_, err = apiPost(cfg, fmt.Sprintf("/api/lem/tasks/%d/result", task.ID), map[string]interface{}{
|
||||
"worker_id": cfg.workerID,
|
||||
"response_text": response,
|
||||
"model_used": modelUsed,
|
||||
"gen_time_ms": int(genTime.Milliseconds()),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("submit result: %w", err)
|
||||
}
|
||||
|
||||
log.Printf(" Completed: %d chars in %v", len(response), genTime.Round(time.Millisecond))
|
||||
return nil
|
||||
}
|
||||
|
||||
func workerInfer(cfg *workerConfig, task apiTask) (string, error) {
|
||||
// Build the chat request for the local inference endpoint.
|
||||
messages := []map[string]string{
|
||||
{"role": "user", "content": task.PromptText},
|
||||
}
|
||||
|
||||
temp := 0.7
|
||||
maxTokens := 2048
|
||||
if task.Config != nil {
|
||||
if task.Config.Temperature > 0 {
|
||||
temp = task.Config.Temperature
|
||||
}
|
||||
if task.Config.MaxTokens > 0 {
|
||||
maxTokens = task.Config.MaxTokens
|
||||
}
|
||||
}
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"model": task.ModelName,
|
||||
"messages": messages,
|
||||
"temperature": temp,
|
||||
"max_tokens": maxTokens,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", cfg.inferURL+"/v1/chat/completions", bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 5 * time.Minute}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("inference request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return "", fmt.Errorf("inference HTTP %d: %s", resp.StatusCode, truncStr(string(body), 200))
|
||||
}
|
||||
|
||||
// Parse OpenAI-compatible response.
|
||||
var chatResp struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &chatResp); err != nil {
|
||||
return "", fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
if len(chatResp.Choices) == 0 {
|
||||
return "", fmt.Errorf("no choices in response")
|
||||
}
|
||||
|
||||
content := chatResp.Choices[0].Message.Content
|
||||
if len(content) < 10 {
|
||||
return "", fmt.Errorf("response too short: %d chars", len(content))
|
||||
}
|
||||
|
||||
return content, nil
|
||||
}
|
||||
|
||||
// HTTP helpers for the LEM API.
|
||||
|
||||
func apiGet(cfg *workerConfig, path string) ([]byte, error) {
|
||||
req, err := http.NewRequest("GET", cfg.apiBase+path, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+cfg.apiKey)
|
||||
|
||||
client := &http.Client{Timeout: 15 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, truncStr(string(body), 200))
|
||||
}
|
||||
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func apiPost(cfg *workerConfig, path string, data map[string]interface{}) ([]byte, error) {
|
||||
return apiRequest(cfg, "POST", path, data)
|
||||
}
|
||||
|
||||
func apiPatch(cfg *workerConfig, path string, data map[string]interface{}) ([]byte, error) {
|
||||
return apiRequest(cfg, "PATCH", path, data)
|
||||
}
|
||||
|
||||
func apiDelete(cfg *workerConfig, path string, data map[string]interface{}) ([]byte, error) {
|
||||
return apiRequest(cfg, "DELETE", path, data)
|
||||
}
|
||||
|
||||
func apiRequest(cfg *workerConfig, method, path string, data map[string]interface{}) ([]byte, error) {
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(method, cfg.apiBase+path, bytes.NewReader(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+cfg.apiKey)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 15 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, truncStr(string(body), 200))
|
||||
}
|
||||
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// Utility functions.
|
||||
|
||||
// machineID returns an identifier for this host: the trimmed contents of
// /etc/machine-id when that file is readable and non-blank, otherwise the
// hostname reported by the OS.
func machineID() string {
	raw, readErr := os.ReadFile("/etc/machine-id")
	if readErr == nil {
		if trimmed := bytes.TrimSpace(raw); len(trimmed) != 0 {
			return string(trimmed)
		}
	}

	// Fallback: the hostname. A lookup error is deliberately ignored; the
	// caller then gets whatever string os.Hostname returned.
	name, _ := os.Hostname()
	return name
}
|
||||
|
||||
// hostname returns the host name reported by os.Hostname. The lookup error
// is deliberately discarded: callers only use the result as a default value.
func hostname() string {
	name, _ := os.Hostname()
	return name
}
|
||||
|
||||
// readKeyFile returns the whitespace-trimmed contents of
// <home>/.config/lem/api_key, or "" when the home directory cannot be
// determined or the file cannot be read.
func readKeyFile() string {
	home, err := os.UserHomeDir()
	if err != nil || home == "" {
		// Without a home directory the joined path would be relative and
		// resolve against the current working directory; never pick up a
		// credential from there.
		return ""
	}

	data, err := os.ReadFile(filepath.Join(home, ".config", "lem", "api_key"))
	if err != nil {
		return ""
	}
	return string(bytes.TrimSpace(data))
}
|
||||
|
||||
// splitComma breaks s on commas, trims surrounding whitespace from every
// piece and drops pieces that end up empty. It returns nil when nothing
// remains.
func splitComma(s string) []string {
	isComma := func(r rune) bool { return r == ',' }

	var items []string
	// FieldsFunc already skips zero-length fields between adjacent commas;
	// whitespace-only fields are filtered after trimming.
	for _, field := range bytes.FieldsFunc([]byte(s), isComma) {
		field = bytes.TrimSpace(field)
		if len(field) == 0 {
			continue
		}
		items = append(items, string(field))
	}
	return items
}
|
||||
|
||||
// truncStr returns s unchanged when it is at most n bytes long; otherwise it
// returns a prefix of at most n bytes followed by "...".
//
// The cut never lands inside a multi-byte UTF-8 sequence: it is moved back to
// the start of the rune that would have been split, so the prefix can be
// slightly shorter than n. A negative n is treated as 0.
func truncStr(s string, n int) string {
	if n < 0 {
		n = 0
	}
	if len(s) <= n {
		return s
	}

	cut := n
	// s[cut] is the first dropped byte. If it is a UTF-8 continuation byte
	// (10xxxxxx) the cut is mid-rune, so step back to the rune's lead byte.
	// A rune has at most three continuation bytes; the bound keeps malformed
	// input from dragging the cut arbitrarily far back.
	for back := 0; back < 3 && cut > 0 && s[cut]&0xC0 == 0x80; back++ {
		cut--
	}
	return s[:cut] + "..."
}
|
||||
197
pkg/lem/worker_test.go
Normal file
197
pkg/lem/worker_test.go
Normal file
|
|
@ -0,0 +1,197 @@
|
|||
package lem
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMachineID(t *testing.T) {
|
||||
id := machineID()
|
||||
if id == "" {
|
||||
t.Error("machineID returned empty string")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostname(t *testing.T) {
|
||||
h := hostname()
|
||||
if h == "" {
|
||||
t.Error("hostname returned empty string")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitComma(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want int
|
||||
}{
|
||||
{"en,yo,sw", 3},
|
||||
{"en", 1},
|
||||
{"", 0},
|
||||
{"en, yo , sw", 3},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := splitComma(tt.input)
|
||||
if len(got) != tt.want {
|
||||
t.Errorf("splitComma(%q) = %d items, want %d", tt.input, len(got), tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncStr(t *testing.T) {
|
||||
if got := truncStr("hello", 10); got != "hello" {
|
||||
t.Errorf("truncStr short = %q", got)
|
||||
}
|
||||
if got := truncStr("hello world", 5); got != "hello..." {
|
||||
t.Errorf("truncStr long = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorkerRegister(t *testing.T) {
|
||||
var gotBody map[string]interface{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/lem/workers/register" {
|
||||
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||
}
|
||||
if r.Header.Get("Authorization") != "Bearer test-key" {
|
||||
t.Errorf("missing auth header")
|
||||
}
|
||||
json.NewDecoder(r.Body).Decode(&gotBody)
|
||||
w.WriteHeader(201)
|
||||
w.Write([]byte(`{"worker":{}}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
cfg := &workerConfig{
|
||||
apiBase: srv.URL,
|
||||
workerID: "test-worker-001",
|
||||
name: "Test Worker",
|
||||
apiKey: "test-key",
|
||||
gpuType: "RTX 3090",
|
||||
vramGb: 24,
|
||||
languages: []string{"en", "yo"},
|
||||
models: []string{"gemma-3-12b"},
|
||||
}
|
||||
|
||||
err := workerRegister(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("register failed: %v", err)
|
||||
}
|
||||
|
||||
if gotBody["worker_id"] != "test-worker-001" {
|
||||
t.Errorf("worker_id = %v", gotBody["worker_id"])
|
||||
}
|
||||
if gotBody["gpu_type"] != "RTX 3090" {
|
||||
t.Errorf("gpu_type = %v", gotBody["gpu_type"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorkerPoll(t *testing.T) {
|
||||
callCount := 0
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
callCount++
|
||||
switch {
|
||||
case r.URL.Path == "/api/lem/tasks/next":
|
||||
// Return one task.
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"tasks": []apiTask{
|
||||
{
|
||||
ID: 42,
|
||||
TaskType: "expand",
|
||||
Language: "yo",
|
||||
Domain: "sovereignty",
|
||||
PromptText: "What does sovereignty mean to you?",
|
||||
ModelName: "gemma-3-12b",
|
||||
Priority: 100,
|
||||
},
|
||||
},
|
||||
"count": 1,
|
||||
})
|
||||
case r.URL.Path == "/api/lem/tasks/42/claim" && r.Method == "POST":
|
||||
w.WriteHeader(201)
|
||||
w.Write([]byte(`{"task":{}}`))
|
||||
case r.URL.Path == "/api/lem/tasks/42/status" && r.Method == "PATCH":
|
||||
w.Write([]byte(`{"task":{}}`))
|
||||
default:
|
||||
w.WriteHeader(404)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
cfg := &workerConfig{
|
||||
apiBase: srv.URL,
|
||||
workerID: "test-worker",
|
||||
apiKey: "test-key",
|
||||
batchSize: 5,
|
||||
dryRun: true, // don't actually run inference
|
||||
}
|
||||
|
||||
processed := workerPoll(cfg)
|
||||
if processed != 1 {
|
||||
t.Errorf("processed = %d, want 1", processed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorkerInfer(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/chat/completions" {
|
||||
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||
}
|
||||
|
||||
var body map[string]interface{}
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
|
||||
if body["temperature"].(float64) != 0.7 {
|
||||
t.Errorf("temperature = %v", body["temperature"])
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"message": map[string]string{
|
||||
"content": "Sovereignty means the inherent right of every individual to self-determination...",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
cfg := &workerConfig{
|
||||
inferURL: srv.URL,
|
||||
}
|
||||
|
||||
task := apiTask{
|
||||
ID: 1,
|
||||
PromptText: "What does sovereignty mean?",
|
||||
ModelName: "gemma-3-12b",
|
||||
}
|
||||
|
||||
response, err := workerInfer(cfg, task)
|
||||
if err != nil {
|
||||
t.Fatalf("infer failed: %v", err)
|
||||
}
|
||||
|
||||
if len(response) < 10 {
|
||||
t.Errorf("response too short: %d chars", len(response))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIErrorHandling(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(500)
|
||||
w.Write([]byte(`{"error":"internal server error"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
cfg := &workerConfig{
|
||||
apiBase: srv.URL,
|
||||
apiKey: "test-key",
|
||||
}
|
||||
|
||||
_, err := apiGet(cfg, "/api/lem/test")
|
||||
if err == nil {
|
||||
t.Error("expected error for 500 response")
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue