LEM/pkg/lem/worker.go
Snider 3606ff994b fix: memory, error handling, and signal improvements across pkg/lem
- Stream parquet export rows instead of unbounded memory allocation
- Replace QueryGoldenSet/QueryExpansionPrompts with iter.Seq2 iterators
- Remove legacy runtime.GC() calls from distill (go-mlx handles cleanup)
- Replace log.Fatalf with error return in tier_score.go
- Add SIGINT/SIGTERM signal handling to agent and worker daemon loops
- Add error checks for unchecked db.conn.Exec in import.go and tier_score.go
- Update tests for iterator-based database methods

Co-Authored-By: Gemini <noreply@google.com>
Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-23 04:46:51 +00:00

521 lines
12 KiB
Go

package lem
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
	"net/url"
	"os"
	"os/signal"
	"path/filepath"
	"runtime"
	"syscall"
	"time"
)
// WorkerOpts holds the worker's runtime configuration as supplied by the
// caller of RunWorker.
//
// RunWorker treats empty/zero values of most fields as "not set" and fills
// them in through envOr/intEnvOr lookups or built-in defaults; the lookup key
// used for each field is noted below.
type WorkerOpts struct {
APIBase string // LEM API base URL; when empty: envOr("LEM_API", "https://infer.lthn.ai")
WorkerID string // Worker ID (machine UUID); when empty: envOr("LEM_WORKER_ID", machineID())
Name string // Worker display name; when empty: envOr("LEM_WORKER_NAME", hostname())
APIKey string // API key; when empty: LEM_API_KEY lookup, then the key file read by readKeyFile
GPUType string // GPU type (e.g. 'RTX 3090'); when empty: LEM_GPU lookup
VRAMGb int // GPU VRAM in GB; when zero: intEnvOr("LEM_VRAM_GB", 0)
Languages string // Comma-separated language codes (e.g. 'en,yo,sw'); when empty: LEM_LANGUAGES lookup
Models string // Comma-separated supported model names; when empty: LEM_MODELS lookup
InferURL string // Local inference endpoint; when empty: envOr("LEM_INFER_URL", "http://localhost:8090")
TaskType string // Filter by task type (expand, score, translate, seed); no environment fallback
BatchSize int // Number of tasks to fetch per poll; values <= 0 become 5
PollInterval time.Duration // Poll interval; values <= 0 become 30 seconds
OneShot bool // Process one batch and exit
DryRun bool // Fetch tasks but don't run inference (tasks are still claimed and marked in_progress)
}
// workerConfig is the internal runtime config populated from WorkerOpts.
//
// RunWorker builds it after applying fallbacks to WorkerOpts, and a pointer
// to it is passed to every worker and API helper in this file.
type workerConfig struct {
apiBase string // prefix prepended to every API path by apiGet/apiRequest
workerID string // sent as "worker_id" in API request bodies and the task query
name string // sent as "name" on registration
apiKey string // sent as the Bearer token in the Authorization header
gpuType string // sent as "gpu_type" on registration when non-empty
vramGb int // sent as "vram_gb" on registration when > 0
languages []string // WorkerOpts.Languages split by splitComma; nil when unset
models []string // WorkerOpts.Models split by splitComma; nil when unset
inferURL string // base URL of the inference server; "/v1/chat/completions" is appended
taskType string // optional "type" filter on the task query
batchSize int // "limit" value on the task query
pollInterval time.Duration // idle wait between polls in RunWorker
oneShot bool // RunWorker returns after the first poll
dryRun bool // workerProcessTask skips inference and result submission
}
// apiTask represents a task from the LEM API.
//
// Values are decoded from the "tasks" array of the task-fetch response using
// the JSON keys in the field tags.
type apiTask struct {
ID int `json:"id"` // used to build the per-task API paths (/api/lem/tasks/<id>/...)
TaskType string `json:"task_type"`
Status string `json:"status"`
Language string `json:"language"`
Domain string `json:"domain"`
ModelName string `json:"model_name"` // forwarded as "model" to the inference endpoint
PromptID string `json:"prompt_id"`
PromptText string `json:"prompt_text"` // sent as the single user message for inference
// Config carries optional per-task generation settings. It may be nil, and
// workerInfer only honours values > 0 (otherwise it uses 0.7 and 2048).
Config *struct {
Temperature float64 `json:"temperature,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
} `json:"config"`
Priority int `json:"priority"`
}
// RunWorker is the CLI entry point for `lem worker`.
//
// It fills unset options from LEM_* environment lookups and built-in
// defaults, resolves an API key (option, then LEM_API_KEY, then the key file
// read by readKeyFile), registers the worker with the API, and then polls for
// tasks until SIGINT or SIGTERM is received. With OneShot set it returns after
// the first poll.
//
// It returns an error when no API key can be found or when registration
// fails. A signal-triggered shutdown and a one-shot run both return nil.
func RunWorker(opts WorkerOpts) error {
	// Apply env-var fallbacks for fields not set by flags.
	if opts.APIBase == "" {
		opts.APIBase = envOr("LEM_API", "https://infer.lthn.ai")
	}
	if opts.WorkerID == "" {
		opts.WorkerID = envOr("LEM_WORKER_ID", machineID())
	}
	if opts.Name == "" {
		opts.Name = envOr("LEM_WORKER_NAME", hostname())
	}
	if opts.APIKey == "" {
		opts.APIKey = envOr("LEM_API_KEY", "")
	}
	if opts.GPUType == "" {
		opts.GPUType = envOr("LEM_GPU", "")
	}
	if opts.VRAMGb == 0 {
		opts.VRAMGb = intEnvOr("LEM_VRAM_GB", 0)
	}
	if opts.Languages == "" {
		opts.Languages = envOr("LEM_LANGUAGES", "")
	}
	if opts.Models == "" {
		opts.Models = envOr("LEM_MODELS", "")
	}
	if opts.InferURL == "" {
		opts.InferURL = envOr("LEM_INFER_URL", "http://localhost:8090")
	}
	if opts.BatchSize <= 0 {
		opts.BatchSize = 5
	}
	if opts.PollInterval <= 0 {
		opts.PollInterval = 30 * time.Second
	}

	// Build internal config.
	cfg := workerConfig{
		apiBase:      opts.APIBase,
		workerID:     opts.WorkerID,
		name:         opts.Name,
		apiKey:       opts.APIKey,
		gpuType:      opts.GPUType,
		vramGb:       opts.VRAMGb,
		inferURL:     opts.InferURL,
		taskType:     opts.TaskType,
		batchSize:    opts.BatchSize,
		pollInterval: opts.PollInterval,
		oneShot:      opts.OneShot,
		dryRun:       opts.DryRun,
	}
	if opts.Languages != "" {
		cfg.languages = splitComma(opts.Languages)
	}
	if opts.Models != "" {
		cfg.models = splitComma(opts.Models)
	}

	// Last resort for the API key: the on-disk key file.
	if cfg.apiKey == "" {
		cfg.apiKey = readKeyFile()
	}
	if cfg.apiKey == "" {
		return fmt.Errorf("no API key — set LEM_API_KEY or run: lem worker-auth")
	}

	log.Printf("LEM Worker starting")
	log.Printf(" ID: %s", cfg.workerID)
	log.Printf(" Name: %s", cfg.name)
	log.Printf(" API: %s", cfg.apiBase)
	log.Printf(" Infer: %s", cfg.inferURL)
	log.Printf(" GPU: %s (%d GB)", cfg.gpuType, cfg.vramGb)
	log.Printf(" Langs: %v", cfg.languages)
	log.Printf(" Models: %v", cfg.models)
	log.Printf(" Batch: %d", cfg.batchSize)
	log.Printf(" Dry-run: %v", cfg.dryRun)

	// Register with the API.
	if err := workerRegister(&cfg); err != nil {
		return fmt.Errorf("registration failed: %w", err)
	}
	log.Println("Registered with LEM API")

	// ctx is cancelled on SIGINT/SIGTERM; it is only observed between polls.
	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
	defer stop()

	// Main loop.
	for {
		processed := workerPoll(&cfg)
		if cfg.oneShot {
			log.Printf("One-shot mode: processed %d tasks, exiting", processed)
			return nil
		}

		if processed == 0 {
			log.Printf("No tasks available, sleeping %v", cfg.pollInterval)
			// Use a fresh timer for every idle period. A long-lived Ticker
			// keeps ticking while a batch is being processed, so a stale
			// buffered tick would end this wait immediately.
			idle := time.NewTimer(cfg.pollInterval)
			select {
			case <-ctx.Done():
				idle.Stop()
				log.Println("Worker shutting down...")
				return nil
			case <-idle.C:
			}
		} else if ctx.Err() != nil {
			// Work was done: do not sleep, but still honour a pending shutdown.
			log.Println("Worker shutting down...")
			return nil
		}

		// Heartbeat.
		workerHeartbeat(&cfg)
	}
}
// workerRegister announces this worker to the LEM API.
//
// The payload always carries the worker ID, name, a fixed version string and
// the Go runtime OS/architecture. GPU type, VRAM, languages and supported
// models are included only when configured. Any apiPost error is returned
// unchanged.
func workerRegister(cfg *workerConfig) error {
	payload := map[string]any{
		"worker_id": cfg.workerID,
		"name":      cfg.name,
		"version":   "0.1.0",
		"os":        runtime.GOOS,
		"arch":      runtime.GOARCH,
	}

	// Optional capabilities: omit the key entirely when the value is unset.
	if gpu := cfg.gpuType; gpu != "" {
		payload["gpu_type"] = gpu
	}
	if vram := cfg.vramGb; vram > 0 {
		payload["vram_gb"] = vram
	}
	if langs := cfg.languages; len(langs) > 0 {
		payload["languages"] = langs
	}
	if models := cfg.models; len(models) > 0 {
		payload["supported_models"] = models
	}

	if _, err := apiPost(cfg, "/api/lem/workers/register", payload); err != nil {
		return err
	}
	return nil
}
// workerHeartbeat posts this worker's ID to the heartbeat endpoint.
//
// The call is best-effort: a failure is logged and the caller carries on.
func workerHeartbeat(cfg *workerConfig) {
	body := map[string]any{
		"worker_id": cfg.workerID,
	}
	if _, err := apiPost(cfg, "/api/lem/workers/heartbeat", body); err != nil {
		log.Printf("Heartbeat failed: %v", err)
	}
}
// workerPoll fetches up to cfg.batchSize tasks for this worker and runs each
// one through workerProcessTask.
//
// It returns the number of tasks that completed without error. Fetch or
// decode failures are logged and reported as 0, the same as "no tasks".
// When a task fails, a DELETE is issued on its claim so another worker can
// pick it up; a failure of that release is logged.
func workerPoll(cfg *workerConfig) int {
	// Fetch available tasks. The worker ID and task type come from flags,
	// environment or hostname, so they are escaped before entering the query.
	endpoint := fmt.Sprintf("/api/lem/tasks/next?worker_id=%s&limit=%d",
		url.QueryEscape(cfg.workerID), cfg.batchSize)
	if cfg.taskType != "" {
		endpoint += "&type=" + url.QueryEscape(cfg.taskType)
	}

	resp, err := apiGet(cfg, endpoint)
	if err != nil {
		log.Printf("Error fetching tasks: %v", err)
		return 0
	}

	var result struct {
		Tasks []apiTask `json:"tasks"`
		Count int       `json:"count"`
	}
	if err := json.Unmarshal(resp, &result); err != nil {
		log.Printf("Error parsing tasks: %v", err)
		return 0
	}
	if result.Count == 0 {
		return 0
	}
	log.Printf("Got %d tasks", result.Count)

	processed := 0
	for _, task := range result.Tasks {
		if err := workerProcessTask(cfg, task); err != nil {
			log.Printf("Task %d failed: %v", task.ID, err)
			// Release the claim so someone else can try.
			claimPath := fmt.Sprintf("/api/lem/tasks/%d/claim", task.ID)
			if _, relErr := apiDelete(cfg, claimPath, map[string]any{
				"worker_id": cfg.workerID,
			}); relErr != nil {
				log.Printf("Task %d: releasing claim failed: %v", task.ID, relErr)
			}
			continue
		}
		processed++
	}
	return processed
}
// workerProcessTask claims task, marks it in_progress, runs inference through
// workerInfer and submits the generated text as the task result.
//
// In dry-run mode it stops after the claim and status update and returns nil
// without running inference or submitting anything.
//
// It returns an error when the claim, the inference or the result submission
// fails. Status updates are best-effort: their failures are logged, not
// returned.
func workerProcessTask(cfg *workerConfig, task apiTask) error {
	log.Printf("Processing task %d: %s [%s/%s] %d chars prompt",
		task.ID, task.TaskType, task.Language, task.Domain, len(task.PromptText))

	statusPath := fmt.Sprintf("/api/lem/tasks/%d/status", task.ID)

	// Claim the task.
	_, err := apiPost(cfg, fmt.Sprintf("/api/lem/tasks/%d/claim", task.ID), map[string]any{
		"worker_id": cfg.workerID,
	})
	if err != nil {
		return fmt.Errorf("claim: %w", err)
	}

	// Update to in_progress.
	if _, err := apiPatch(cfg, statusPath, map[string]any{
		"worker_id": cfg.workerID,
		"status":    "in_progress",
	}); err != nil {
		log.Printf(" Status update to in_progress failed: %v", err)
	}

	if cfg.dryRun {
		// NOTE(review): the task stays claimed and in_progress after a dry
		// run; nothing in this function releases it. Confirm this is intended.
		log.Printf(" [DRY-RUN] Would generate response for: %.80s...", task.PromptText)
		return nil
	}

	// Run inference via local API.
	start := time.Now()
	response, err := workerInfer(cfg, task)
	genTime := time.Since(start)
	if err != nil {
		// Report failure, release task.
		if _, patchErr := apiPatch(cfg, statusPath, map[string]any{
			"worker_id": cfg.workerID,
			"status":    "abandoned",
		}); patchErr != nil {
			log.Printf(" Status update to abandoned failed: %v", patchErr)
		}
		return fmt.Errorf("inference: %w", err)
	}

	// Submit result.
	modelUsed := task.ModelName
	if modelUsed == "" {
		modelUsed = "default"
	}
	_, err = apiPost(cfg, fmt.Sprintf("/api/lem/tasks/%d/result", task.ID), map[string]any{
		"worker_id":     cfg.workerID,
		"response_text": response,
		"model_used":    modelUsed,
		"gen_time_ms":   int(genTime.Milliseconds()),
	})
	if err != nil {
		return fmt.Errorf("submit result: %w", err)
	}

	log.Printf(" Completed: %d chars in %v", len(response), genTime.Round(time.Millisecond))
	return nil
}
// workerInfer sends the task prompt as a single user message to the local
// inference server's /v1/chat/completions endpoint and returns the content of
// the first choice.
//
// Temperature and max tokens default to 0.7 and 2048 and are overridden by
// positive values in task.Config. An error is returned for transport
// failures, any non-200 status, an undecodable body, an empty choices list,
// or content shorter than 10 bytes.
func workerInfer(cfg *workerConfig, task apiTask) (string, error) {
	// Sampling defaults, overridden by any positive per-task values.
	temperature, tokenLimit := 0.7, 2048
	if c := task.Config; c != nil {
		if c.Temperature > 0 {
			temperature = c.Temperature
		}
		if c.MaxTokens > 0 {
			tokenLimit = c.MaxTokens
		}
	}

	// Build the chat request for the local inference endpoint.
	payload, err := json.Marshal(map[string]any{
		"model": task.ModelName,
		"messages": []map[string]string{
			{"role": "user", "content": task.PromptText},
		},
		"temperature": temperature,
		"max_tokens":  tokenLimit,
	})
	if err != nil {
		return "", err
	}

	endpoint := cfg.inferURL + "/v1/chat/completions"
	req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		return "", err
	}
	req.Header.Set("Content-Type", "application/json")

	httpClient := &http.Client{Timeout: 5 * time.Minute}
	resp, err := httpClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("inference request: %w", err)
	}
	defer resp.Body.Close()

	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("read response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("inference HTTP %d: %s", resp.StatusCode, truncStr(string(raw), 200))
	}

	// Parse OpenAI-compatible response.
	var parsed struct {
		Choices []struct {
			Message struct {
				Content string `json:"content"`
			} `json:"message"`
		} `json:"choices"`
	}
	if err := json.Unmarshal(raw, &parsed); err != nil {
		return "", fmt.Errorf("parse response: %w", err)
	}
	if len(parsed.Choices) == 0 {
		return "", fmt.Errorf("no choices in response")
	}

	answer := parsed.Choices[0].Message.Content
	if size := len(answer); size < 10 {
		return "", fmt.Errorf("response too short: %d chars", size)
	}
	return answer, nil
}
// HTTP helpers for the LEM API.

// apiGet issues an authenticated GET for path (relative to cfg.apiBase) with
// a 15-second client timeout and returns the raw response body.
//
// Responses with status >= 400 are returned as an error that carries the
// status code and up to 200 bytes of the body.
func apiGet(cfg *workerConfig, path string) ([]byte, error) {
	req, err := http.NewRequest(http.MethodGet, cfg.apiBase+path, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Authorization", "Bearer "+cfg.apiKey)

	httpClient := &http.Client{Timeout: 15 * time.Second}
	resp, err := httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	payload, err := io.ReadAll(resp.Body)
	switch {
	case err != nil:
		return nil, err
	case resp.StatusCode >= 400:
		return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, truncStr(string(payload), 200))
	}
	return payload, nil
}
// apiPost sends data as a JSON body in a POST to path via apiRequest.
func apiPost(cfg *workerConfig, path string, data map[string]any) ([]byte, error) {
	return apiRequest(cfg, http.MethodPost, path, data)
}
// apiPatch sends data as a JSON body in a PATCH to path via apiRequest.
func apiPatch(cfg *workerConfig, path string, data map[string]any) ([]byte, error) {
	return apiRequest(cfg, http.MethodPatch, path, data)
}
// apiDelete sends data as a JSON body in a DELETE to path via apiRequest.
func apiDelete(cfg *workerConfig, path string, data map[string]any) ([]byte, error) {
	return apiRequest(cfg, http.MethodDelete, path, data)
}
// apiRequest JSON-encodes data and sends it with the given HTTP method to
// path (relative to cfg.apiBase), authenticated with the configured API key
// and bounded by a 15-second client timeout.
//
// It returns the raw response body. Responses with status >= 400 are
// returned as an error that carries the status code and up to 200 bytes of
// the body.
func apiRequest(cfg *workerConfig, method, path string, data map[string]any) ([]byte, error) {
	encoded, err := json.Marshal(data)
	if err != nil {
		return nil, err
	}

	req, err := http.NewRequest(method, cfg.apiBase+path, bytes.NewReader(encoded))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Authorization", "Bearer "+cfg.apiKey)
	req.Header.Set("Content-Type", "application/json")

	httpClient := &http.Client{Timeout: 15 * time.Second}
	resp, err := httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	payload, err := io.ReadAll(resp.Body)
	switch {
	case err != nil:
		return nil, err
	case resp.StatusCode >= 400:
		return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, truncStr(string(payload), 200))
	}
	return payload, nil
}
// Utility functions.

// machineID returns an identifier for this machine: the whitespace-trimmed
// contents of /etc/machine-id when that file is readable and non-empty,
// otherwise the OS hostname (empty if that lookup fails too).
func machineID() string {
	raw, err := os.ReadFile("/etc/machine-id")
	if err == nil {
		if id := bytes.TrimSpace(raw); len(id) != 0 {
			return string(id)
		}
	}

	// No usable machine-id file: fall back to the hostname.
	name, _ := os.Hostname()
	return name
}
// hostname returns the name reported by os.Hostname. The lookup error is
// discarded, so whatever name accompanies a failure is returned as-is.
func hostname() string {
	name, _ := os.Hostname()
	return name
}
// readKeyFile returns the whitespace-trimmed contents of
// <home>/.config/lem/api_key, where <home> is the result of os.UserHomeDir.
//
// It returns "" when the home directory cannot be determined or the file
// cannot be read. Bailing out on an unknown home directory matters: joining
// an empty home would give a relative path and pick up a key file from
// whatever the current working directory happens to be.
func readKeyFile() string {
	home, err := os.UserHomeDir()
	if err != nil || home == "" {
		return ""
	}
	data, err := os.ReadFile(filepath.Join(home, ".config", "lem", "api_key"))
	if err != nil {
		return ""
	}
	return string(bytes.TrimSpace(data))
}
// splitComma breaks s on commas, trims surrounding whitespace from every
// piece and drops pieces that end up empty. It returns nil when nothing
// remains.
func splitComma(s string) []string {
	var fields []string
	comma := []byte{','}
	rest := []byte(s)
	for {
		piece, tail, more := bytes.Cut(rest, comma)
		if piece = bytes.TrimSpace(piece); len(piece) != 0 {
			fields = append(fields, string(piece))
		}
		if !more {
			return fields
		}
		rest = tail
	}
}
// truncStr shortens s to at most n bytes and appends "..." when anything was
// cut off. Strings of n bytes or fewer are returned unchanged.
//
// The cut never lands inside a multi-byte UTF-8 sequence: it is moved back
// (by at most 3 bytes) to the start of the rune it would otherwise split. A
// negative n is treated as 0.
func truncStr(s string, n int) string {
	if n < 0 {
		n = 0
	}
	if len(s) <= n {
		return s
	}

	cut := n
	// s[cut] is the first byte being dropped. While it is a continuation byte
	// (bit pattern 10xxxxxx) the cut sits mid-rune, so step back. A rune has
	// at most 3 continuation bytes, which bounds the loop for malformed input.
	for back := 0; back < 3 && cut > 0 && s[cut]&0xC0 == 0x80; back++ {
		cut--
	}
	return s[:cut] + "..."
}