diff --git a/main.go b/main.go
index c57e0ba..9011398 100644
--- a/main.go
+++ b/main.go
@@ -42,6 +42,9 @@ Monitoring:
   coverage     Analyze seed coverage gaps
   metrics      Push DuckDB golden set stats to InfluxDB
 
+Distributed:
+  worker       Run as distributed inference worker node
+
 Infrastructure:
   ingest       Ingest benchmark data into InfluxDB
   seed-influx  Seed InfluxDB golden_gen from DuckDB
@@ -101,6 +104,8 @@ func main() {
 		lem.RunQuery(os.Args[2:])
 	case "agent":
 		lem.RunAgent(os.Args[2:])
+	case "worker":
+		lem.RunWorker(os.Args[2:])
 	default:
 		fmt.Fprintf(os.Stderr, "unknown command: %s\n\n%s", os.Args[1], usage)
 		os.Exit(1)
diff --git a/pkg/lem/worker.go b/pkg/lem/worker.go
new file mode 100644
index 0000000..2a1cbff
--- /dev/null
+++ b/pkg/lem/worker.go
@@ -0,0 +1,454 @@
+package lem
+
+import (
+	"bytes"
+	"encoding/json"
+	"flag"
+	"fmt"
+	"io"
+	"log"
+	"net/http"
+	"os"
+	"path/filepath"
+	"runtime"
+	"time"
+)
+
+// workerConfig holds the worker's runtime configuration.
+// Populated in RunWorker from CLI flags, with LEM_* environment
+// variables as fallbacks (see the fs.StringVar/IntVar calls there).
+type workerConfig struct {
+	apiBase      string        // LEM API base URL
+	workerID     string        // stable worker identity (machine ID, falling back to hostname)
+	name         string        // display name sent to the API on registration
+	apiKey       string        // bearer token attached to every API request
+	gpuType      string        // optional, e.g. "RTX 3090"; omitted from registration when empty
+	vramGb       int           // optional GPU VRAM in GB; omitted from registration when 0
+	languages    []string      // language codes this worker serves
+	models       []string      // model names this worker supports
+	inferURL     string        // local OpenAI-compatible inference endpoint
+	taskType     string        // optional task-type filter appended to the poll query
+	batchSize    int           // max tasks requested per poll
+	pollInterval time.Duration // sleep between polls that return no tasks
+	oneShot      bool          // process one batch then exit
+	dryRun       bool          // claim tasks but skip actual inference
+}
+
+// apiTask represents a task from the LEM API.
+type apiTask struct {
+	ID         int    `json:"id"`
+	TaskType   string `json:"task_type"`
+	Status     string `json:"status"`
+	Language   string `json:"language"`
+	Domain     string `json:"domain"`
+	ModelName  string `json:"model_name"`
+	PromptID   string `json:"prompt_id"`
+	PromptText string `json:"prompt_text"`
+	// Config carries optional per-task generation overrides; nil (or
+	// zero-valued fields) fall back to workerInfer's defaults.
+	Config *struct {
+		Temperature float64 `json:"temperature,omitempty"`
+		MaxTokens   int     `json:"max_tokens,omitempty"`
+	} `json:"config"`
+	Priority int `json:"priority"`
+}
+
+// RunWorker is the CLI entry point for `lem worker`.
+// It parses flags, registers the worker with the LEM API, then polls
+// for tasks forever (or once, with -one-shot). Fatal-exits when no API
+// key can be resolved or registration fails.
+func RunWorker(args []string) {
+	fs := flag.NewFlagSet("worker", flag.ExitOnError)
+
+	cfg := workerConfig{}
+	var langs, mods string
+
+	// NOTE(review): envOr/intEnvOr are not defined in this file —
+	// presumably package-level env helpers; confirm they exist in lem.
+	fs.StringVar(&cfg.apiBase, "api", envOr("LEM_API", "https://infer.lthn.ai"), "LEM API base URL")
+	fs.StringVar(&cfg.workerID, "id", envOr("LEM_WORKER_ID", machineID()), "Worker ID (machine UUID)")
+	fs.StringVar(&cfg.name, "name", envOr("LEM_WORKER_NAME", hostname()), "Worker display name")
+	fs.StringVar(&cfg.apiKey, "key", envOr("LEM_API_KEY", ""), "API key (or use LEM_API_KEY env)")
+	fs.StringVar(&cfg.gpuType, "gpu", envOr("LEM_GPU", ""), "GPU type (e.g. 'RTX 3090')")
+	fs.IntVar(&cfg.vramGb, "vram", intEnvOr("LEM_VRAM_GB", 0), "GPU VRAM in GB")
+	fs.StringVar(&langs, "languages", envOr("LEM_LANGUAGES", ""), "Comma-separated language codes (e.g. 'en,yo,sw')")
+	fs.StringVar(&mods, "models", envOr("LEM_MODELS", ""), "Comma-separated supported model names")
+	fs.StringVar(&cfg.inferURL, "infer", envOr("LEM_INFER_URL", "http://localhost:8090"), "Local inference endpoint")
+	fs.StringVar(&cfg.taskType, "type", "", "Filter by task type (expand, score, translate, seed)")
+	fs.IntVar(&cfg.batchSize, "batch", 5, "Number of tasks to fetch per poll")
+	dur := fs.Duration("poll", 30*time.Second, "Poll interval")
+	fs.BoolVar(&cfg.oneShot, "one-shot", false, "Process one batch and exit")
+	fs.BoolVar(&cfg.dryRun, "dry-run", false, "Fetch tasks but don't run inference")
+
+	// flag.ExitOnError makes Parse exit on bad input, so the error is
+	// intentionally not checked here.
+	fs.Parse(args)
+	cfg.pollInterval = *dur
+
+	if langs != "" {
+		cfg.languages = splitComma(langs)
+	}
+	if mods != "" {
+		cfg.models = splitComma(mods)
+	}
+
+	// Key resolution order: -key flag / LEM_API_KEY env, then the
+	// on-disk key file (~/.config/lem/api_key).
+	if cfg.apiKey == "" {
+		cfg.apiKey = readKeyFile()
+	}
+
+	if cfg.apiKey == "" {
+		log.Fatal("No API key. Set LEM_API_KEY or run: lem worker-auth")
+	}
+
+	// Startup banner: echo the effective configuration.
+	log.Printf("LEM Worker starting")
+	log.Printf(" ID: %s", cfg.workerID)
+	log.Printf(" Name: %s", cfg.name)
+	log.Printf(" API: %s", cfg.apiBase)
+	log.Printf(" Infer: %s", cfg.inferURL)
+	log.Printf(" GPU: %s (%d GB)", cfg.gpuType, cfg.vramGb)
+	log.Printf(" Langs: %v", cfg.languages)
+	log.Printf(" Models: %v", cfg.models)
+	log.Printf(" Batch: %d", cfg.batchSize)
+	log.Printf(" Dry-run: %v", cfg.dryRun)
+
+	// Register with the API.
+	if err := workerRegister(&cfg); err != nil {
+		log.Fatalf("Registration failed: %v", err)
+	}
+	log.Println("Registered with LEM API")
+
+	// Main loop: poll, optionally sleep when idle, then heartbeat.
+	for {
+		processed := workerPoll(&cfg)
+
+		if cfg.oneShot {
+			log.Printf("One-shot mode: processed %d tasks, exiting", processed)
+			return
+		}
+
+		if processed == 0 {
+			log.Printf("No tasks available, sleeping %v", cfg.pollInterval)
+			time.Sleep(cfg.pollInterval)
+		}
+
+		// Heartbeat. Sent once per loop iteration, after any sleep.
+		workerHeartbeat(&cfg)
+	}
+}
+
+// workerRegister announces this worker to the LEM API, sending identity,
+// platform info, and optional capability fields (GPU, VRAM, languages,
+// models) only when they are set.
+func workerRegister(cfg *workerConfig) error {
+	body := map[string]interface{}{
+		"worker_id": cfg.workerID,
+		"name":      cfg.name,
+		"version":   "0.1.0",
+		"os":        runtime.GOOS,
+		"arch":      runtime.GOARCH,
+	}
+	if cfg.gpuType != "" {
+		body["gpu_type"] = cfg.gpuType
+	}
+	if cfg.vramGb > 0 {
+		body["vram_gb"] = cfg.vramGb
+	}
+	if len(cfg.languages) > 0 {
+		body["languages"] = cfg.languages
+	}
+	if len(cfg.models) > 0 {
+		body["supported_models"] = cfg.models
+	}
+
+	_, err := apiPost(cfg, "/api/lem/workers/register", body)
+	return err
+}
+
+// workerHeartbeat pings the API to signal liveness. Best-effort: the
+// response body and any error are deliberately discarded.
+func workerHeartbeat(cfg *workerConfig) {
+	body := map[string]interface{}{
+		"worker_id": cfg.workerID,
+	}
+	apiPost(cfg, "/api/lem/workers/heartbeat", body)
+}
+
+// workerPoll fetches up to batchSize tasks and processes each one,
+// returning the number completed successfully.
+func workerPoll(cfg *workerConfig) int {
+	// Fetch available tasks.
+	// NOTE(review): worker_id/type are interpolated without
+	// url.QueryEscape — fine for machine IDs/hostnames, but confirm no
+	// reserved characters can appear here.
+	url := fmt.Sprintf("/api/lem/tasks/next?worker_id=%s&limit=%d", cfg.workerID, cfg.batchSize)
+	if cfg.taskType != "" {
+		url += "&type=" + cfg.taskType
+	}
+
+	resp, err := apiGet(cfg, url)
+	if err != nil {
+		// Poll errors are logged, not fatal: the caller just retries
+		// on the next loop iteration.
+		log.Printf("Error fetching tasks: %v", err)
+		return 0
+	}
+
+	var result struct {
+		Tasks []apiTask `json:"tasks"`
+		Count int       `json:"count"`
+	}
+	if err := json.Unmarshal(resp, &result); err != nil {
+		log.Printf("Error parsing tasks: %v", err)
+		return 0
+	}
+
+	if result.Count == 0 {
+		return 0
+	}
+
+	log.Printf("Got %d tasks", result.Count)
+	processed := 0
+
+	for _, task := range result.Tasks {
+		if err := workerProcessTask(cfg, task); err != nil {
+			log.Printf("Task %d failed: %v", task.ID, err)
+			// Release the claim so someone else can try.
+			// (Best-effort: the DELETE's error is ignored.)
+			apiDelete(cfg, fmt.Sprintf("/api/lem/tasks/%d/claim", task.ID), map[string]interface{}{
+				"worker_id": cfg.workerID,
+			})
+			continue
+		}
+		processed++
+	}
+
+	return processed
+}
+
+// workerProcessTask claims one task, runs inference for it (unless
+// dry-run), and submits the result. Returns an error when the claim,
+// inference, or result submission fails.
+func workerProcessTask(cfg *workerConfig, task apiTask) error {
+	log.Printf("Processing task %d: %s [%s/%s] %d chars prompt",
+		task.ID, task.TaskType, task.Language, task.Domain, len(task.PromptText))
+
+	// Claim the task.
+	_, err := apiPost(cfg, fmt.Sprintf("/api/lem/tasks/%d/claim", task.ID), map[string]interface{}{
+		"worker_id": cfg.workerID,
+	})
+	if err != nil {
+		return fmt.Errorf("claim: %w", err)
+	}
+
+	// Update to in_progress. Best-effort: a status-update failure does
+	// not abort the task.
+	apiPatch(cfg, fmt.Sprintf("/api/lem/tasks/%d/status", task.ID), map[string]interface{}{
+		"worker_id": cfg.workerID,
+		"status":    "in_progress",
+	})
+
+	if cfg.dryRun {
+		log.Printf(" [DRY-RUN] Would generate response for: %.80s...", task.PromptText)
+		return nil
+	}
+
+	// Run inference via local API.
+	start := time.Now()
+	response, err := workerInfer(cfg, task)
+	genTime := time.Since(start)
+
+	if err != nil {
+		// Report failure, release task.
+		// Mark abandoned so the scheduler can reassign it; the caller
+		// (workerPoll) additionally releases the claim.
+		apiPatch(cfg, fmt.Sprintf("/api/lem/tasks/%d/status", task.ID), map[string]interface{}{
+			"worker_id": cfg.workerID,
+			"status":    "abandoned",
+		})
+		return fmt.Errorf("inference: %w", err)
+	}
+
+	// Submit result.
+	modelUsed := task.ModelName
+	if modelUsed == "" {
+		modelUsed = "default"
+	}
+
+	_, err = apiPost(cfg, fmt.Sprintf("/api/lem/tasks/%d/result", task.ID), map[string]interface{}{
+		"worker_id":     cfg.workerID,
+		"response_text": response,
+		"model_used":    modelUsed,
+		"gen_time_ms":   int(genTime.Milliseconds()),
+	})
+	if err != nil {
+		return fmt.Errorf("submit result: %w", err)
+	}
+
+	log.Printf(" Completed: %d chars in %v", len(response), genTime.Round(time.Millisecond))
+	return nil
+}
+
+// workerInfer sends the task prompt to the local OpenAI-compatible
+// chat-completions endpoint and returns the generated text.
+func workerInfer(cfg *workerConfig, task apiTask) (string, error) {
+	// Build the chat request for the local inference endpoint.
+	messages := []map[string]string{
+		{"role": "user", "content": task.PromptText},
+	}
+
+	// Defaults, overridden by per-task config when present and positive.
+	temp := 0.7
+	maxTokens := 2048
+	if task.Config != nil {
+		if task.Config.Temperature > 0 {
+			temp = task.Config.Temperature
+		}
+		if task.Config.MaxTokens > 0 {
+			maxTokens = task.Config.MaxTokens
+		}
+	}
+
+	reqBody := map[string]interface{}{
+		"model":       task.ModelName,
+		"messages":    messages,
+		"temperature": temp,
+		"max_tokens":  maxTokens,
+	}
+
+	data, err := json.Marshal(reqBody)
+	if err != nil {
+		return "", err
+	}
+
+	req, err := http.NewRequest("POST", cfg.inferURL+"/v1/chat/completions", bytes.NewReader(data))
+	if err != nil {
+		return "", err
+	}
+	req.Header.Set("Content-Type", "application/json")
+
+	// Generation can be slow; allow up to five minutes per request.
+	client := &http.Client{Timeout: 5 * time.Minute}
+	resp, err := client.Do(req)
+	if err != nil {
+		return "", fmt.Errorf("inference request: %w", err)
+	}
+	defer resp.Body.Close()
+
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return "", fmt.Errorf("read response: %w", err)
+	}
+
+	if resp.StatusCode != 200 {
+		return "", fmt.Errorf("inference HTTP %d: %s", resp.StatusCode, truncStr(string(body), 200))
+	}
+
+	// Parse OpenAI-compatible response.
+	var chatResp struct {
+		Choices []struct {
+			Message struct {
+				Content string `json:"content"`
+			} `json:"message"`
+		} `json:"choices"`
+	}
+	if err := json.Unmarshal(body, &chatResp); err != nil {
+		return "", fmt.Errorf("parse response: %w", err)
+	}
+
+	if len(chatResp.Choices) == 0 {
+		return "", fmt.Errorf("no choices in response")
+	}
+
+	content := chatResp.Choices[0].Message.Content
+	// Guard against degenerate/empty generations.
+	if len(content) < 10 {
+		return "", fmt.Errorf("response too short: %d chars", len(content))
+	}
+
+	return content, nil
+}
+
+// HTTP helpers for the LEM API.
+
+// apiGet performs an authenticated GET against the LEM API and returns
+// the raw response body; any status >= 400 becomes an error carrying a
+// truncated copy of the body.
+func apiGet(cfg *workerConfig, path string) ([]byte, error) {
+	req, err := http.NewRequest("GET", cfg.apiBase+path, nil)
+	if err != nil {
+		return nil, err
+	}
+	req.Header.Set("Authorization", "Bearer "+cfg.apiKey)
+
+	client := &http.Client{Timeout: 15 * time.Second}
+	resp, err := client.Do(req)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, err
+	}
+
+	if resp.StatusCode >= 400 {
+		return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, truncStr(string(body), 200))
+	}
+
+	return body, nil
+}
+
+// apiPost/apiPatch/apiDelete are thin method wrappers over apiRequest.
+func apiPost(cfg *workerConfig, path string, data map[string]interface{}) ([]byte, error) {
+	return apiRequest(cfg, "POST", path, data)
+}
+
+func apiPatch(cfg *workerConfig, path string, data map[string]interface{}) ([]byte, error) {
+	return apiRequest(cfg, "PATCH", path, data)
+}
+
+func apiDelete(cfg *workerConfig, path string, data map[string]interface{}) ([]byte, error) {
+	return apiRequest(cfg, "DELETE", path, data)
+}
+
+// apiRequest marshals data as JSON and sends it to the LEM API with the
+// bearer token; mirrors apiGet's status/error handling.
+func apiRequest(cfg *workerConfig, method, path string, data map[string]interface{}) ([]byte, error) {
+	jsonData, err := json.Marshal(data)
+	if err != nil {
+		return nil, err
+	}
+
+	req, err := http.NewRequest(method, cfg.apiBase+path, bytes.NewReader(jsonData))
+	if err != nil {
+		return nil, err
+	}
+	req.Header.Set("Authorization", "Bearer "+cfg.apiKey)
+	req.Header.Set("Content-Type", "application/json")
+
+	client := &http.Client{Timeout: 15 * time.Second}
+	resp, err := client.Do(req)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, err
+	}
+
+	if resp.StatusCode >= 400 {
+		return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, truncStr(string(body), 200))
+	}
+
+	return body, nil
+}
+
+// Utility functions.
+
+// machineID returns a stable identity for this host: the contents of
+// /etc/machine-id when readable and non-empty, else the hostname.
+func machineID() string {
+	// Try /etc/machine-id (Linux), then hostname fallback.
+	if data, err := os.ReadFile("/etc/machine-id"); err == nil {
+		id := string(bytes.TrimSpace(data))
+		if len(id) > 0 {
+			return id
+		}
+	}
+	h, _ := os.Hostname()
+	return h
+}
+
+// hostname returns the OS hostname, or "" on error.
+func hostname() string {
+	h, _ := os.Hostname()
+	return h
+}
+
+// readKeyFile loads the API key from ~/.config/lem/api_key, returning
+// "" when the file is missing or unreadable.
+func readKeyFile() string {
+	home, _ := os.UserHomeDir()
+	path := filepath.Join(home, ".config", "lem", "api_key")
+	data, err := os.ReadFile(path)
+	if err != nil {
+		return ""
+	}
+	return string(bytes.TrimSpace(data))
+}
+
+// splitComma splits a comma-separated list, trimming whitespace and
+// dropping empty entries. Implemented with bytes (not strings) —
+// presumably to avoid adding a strings import to this file.
+func splitComma(s string) []string {
+	var result []string
+	for _, part := range bytes.Split([]byte(s), []byte(",")) {
+		trimmed := bytes.TrimSpace(part)
+		if len(trimmed) > 0 {
+			result = append(result, string(trimmed))
+		}
+	}
+	return result
+}
+
+// truncStr caps s at n bytes, appending "..." when truncated (so the
+// result may exceed n characters).
+// NOTE(review): byte slicing can split a multi-byte UTF-8 rune at the
+// cut point — acceptable for log/error snippets, but confirm.
+func truncStr(s string, n int) string {
+	maxLen := n
+	if len(s) <= maxLen {
+		return s
+	}
+	return s[:maxLen] + "..."
+}
diff --git a/pkg/lem/worker_test.go b/pkg/lem/worker_test.go
new file mode 100644
index 0000000..7923ed3
--- /dev/null
+++ b/pkg/lem/worker_test.go
@@ -0,0 +1,197 @@
+package lem
+
+import (
+	"encoding/json"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+)
+
+// TestMachineID checks machineID never returns an empty identity
+// (it falls back to the hostname).
+func TestMachineID(t *testing.T) {
+	id := machineID()
+	if id == "" {
+		t.Error("machineID returned empty string")
+	}
+}
+
+func TestHostname(t *testing.T) {
+	h := hostname()
+	if h == "" {
+		t.Error("hostname returned empty string")
+	}
+}
+
+// TestSplitComma covers trimming, single-element, and empty-string cases.
+func TestSplitComma(t *testing.T) {
+	tests := []struct {
+		input string
+		want  int
+	}{
+		{"en,yo,sw", 3},
+		{"en", 1},
+		{"", 0},
+		{"en, yo , sw", 3},
+	}
+	for _, tt := range tests {
+		got := splitComma(tt.input)
+		if len(got) != tt.want {
+			t.Errorf("splitComma(%q) = %d items, want %d", tt.input, len(got), tt.want)
+		}
+	}
+}
+
+// TestTruncStr pins the documented behavior: the ellipsis is appended
+// after the n-byte cut, so the truncated result is longer than n.
+func TestTruncStr(t *testing.T) {
+	if got := truncStr("hello", 10); got != "hello" {
+		t.Errorf("truncStr short = %q", got)
+	}
+	if got := truncStr("hello world", 5); got != "hello..." {
+		t.Errorf("truncStr long = %q", got)
+	}
+}
+
+// TestWorkerRegister verifies the registration request's path, auth
+// header, and JSON body against a stub API server.
+func TestWorkerRegister(t *testing.T) {
+	var gotBody map[string]interface{}
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if r.URL.Path != "/api/lem/workers/register" {
+			t.Errorf("unexpected path: %s", r.URL.Path)
+		}
+		if r.Header.Get("Authorization") != "Bearer test-key" {
+			t.Errorf("missing auth header")
+		}
+		json.NewDecoder(r.Body).Decode(&gotBody)
+		w.WriteHeader(201)
+		w.Write([]byte(`{"worker":{}}`))
+	}))
+	defer srv.Close()
+
+	cfg := &workerConfig{
+		apiBase:   srv.URL,
+		workerID:  "test-worker-001",
+		name:      "Test Worker",
+		apiKey:    "test-key",
+		gpuType:   "RTX 3090",
+		vramGb:    24,
+		languages: []string{"en", "yo"},
+		models:    []string{"gemma-3-12b"},
+	}
+
+	err := workerRegister(cfg)
+	if err != nil {
+		t.Fatalf("register failed: %v", err)
+	}
+
+	if gotBody["worker_id"] != "test-worker-001" {
+		t.Errorf("worker_id = %v", gotBody["worker_id"])
+	}
+	if gotBody["gpu_type"] != "RTX 3090" {
+		t.Errorf("gpu_type = %v", gotBody["gpu_type"])
+	}
+}
+
+// TestWorkerPoll drives a full poll/claim/status cycle (dry-run, so no
+// inference) against a stub API.
+func TestWorkerPoll(t *testing.T) {
+	// NOTE(review): callCount is incremented but never asserted.
+	callCount := 0
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		callCount++
+		switch {
+		case r.URL.Path == "/api/lem/tasks/next":
+			// Return one task.
+			json.NewEncoder(w).Encode(map[string]interface{}{
+				"tasks": []apiTask{
+					{
+						ID:         42,
+						TaskType:   "expand",
+						Language:   "yo",
+						Domain:     "sovereignty",
+						PromptText: "What does sovereignty mean to you?",
+						ModelName:  "gemma-3-12b",
+						Priority:   100,
+					},
+				},
+				"count": 1,
+			})
+		case r.URL.Path == "/api/lem/tasks/42/claim" && r.Method == "POST":
+			w.WriteHeader(201)
+			w.Write([]byte(`{"task":{}}`))
+		case r.URL.Path == "/api/lem/tasks/42/status" && r.Method == "PATCH":
+			w.Write([]byte(`{"task":{}}`))
+		default:
+			w.WriteHeader(404)
+		}
+	}))
+	defer srv.Close()
+
+	cfg := &workerConfig{
+		apiBase:   srv.URL,
+		workerID:  "test-worker",
+		apiKey:    "test-key",
+		batchSize: 5,
+		dryRun:    true, // don't actually run inference
+	}
+
+	processed := workerPoll(cfg)
+	if processed != 1 {
+		t.Errorf("processed = %d, want 1", processed)
+	}
+}
+
+// TestWorkerInfer checks that workerInfer sends the default temperature
+// and parses an OpenAI-style chat-completions response.
+func TestWorkerInfer(t *testing.T) {
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if r.URL.Path != "/v1/chat/completions" {
+			t.Errorf("unexpected path: %s", r.URL.Path)
+		}
+
+		var body map[string]interface{}
+		json.NewDecoder(r.Body).Decode(&body)
+
+		if body["temperature"].(float64) != 0.7 {
+			t.Errorf("temperature = %v", body["temperature"])
+		}
+
+		json.NewEncoder(w).Encode(map[string]interface{}{
+			"choices": []map[string]interface{}{
+				{
+					"message": map[string]string{
+						"content": "Sovereignty means the inherent right of every individual to self-determination...",
+					},
+				},
+			},
+		})
+	}))
+	defer srv.Close()
+
+	cfg := &workerConfig{
+		inferURL: srv.URL,
+	}
+
+	task := apiTask{
+		ID:         1,
+		PromptText: "What does sovereignty mean?",
+		ModelName:  "gemma-3-12b",
+	}
+
+	response, err := workerInfer(cfg, task)
+	if err != nil {
+		t.Fatalf("infer failed: %v", err)
+	}
+
+	if len(response) < 10 {
+		t.Errorf("response too short: %d chars", len(response))
+	}
+}
+
+// TestAPIErrorHandling verifies that a 500 from the API surfaces as a
+// non-nil error from apiGet.
+func TestAPIErrorHandling(t *testing.T) {
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(500)
+		w.Write([]byte(`{"error":"internal server error"}`))
+	}))
+	defer srv.Close()
+
+	cfg := &workerConfig{
+		apiBase: srv.URL,
+		apiKey:  "test-key",
+	}
+
+	_, err := apiGet(cfg, "/api/lem/test")
+	if err == nil {
+		t.Error("expected error for 500 response")
+	}
+}