ide/claude_bridge.go
2026-02-16 13:55:59 +00:00

183 lines
4.4 KiB
Go

package main
import (
"encoding/json"
"log"
"net/http"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
)
var wsUpgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
origin := r.Header.Get("Origin")
if origin == "" {
return true // Allow requests with no Origin header (same-origin)
}
host := r.Host
return origin == "http://"+host || origin == "https://"+host ||
strings.HasPrefix(origin, "http://localhost") || strings.HasPrefix(origin, "http://127.0.0.1")
},
}
// ClaudeBridge forwards messages between GUI clients and the MCP core WebSocket.
type ClaudeBridge struct {
mcpConn *websocket.Conn
mcpURL string
clients map[*websocket.Conn]bool
clientsMu sync.RWMutex
broadcast chan []byte
reconnectMu sync.Mutex
}
// NewClaudeBridge creates a new bridge to the MCP core WebSocket.
func NewClaudeBridge(mcpURL string) *ClaudeBridge {
return &ClaudeBridge{
mcpURL: mcpURL,
clients: make(map[*websocket.Conn]bool),
broadcast: make(chan []byte, 256),
}
}
// Start connects to the MCP WebSocket and starts the bridge.
func (cb *ClaudeBridge) Start() {
go cb.connectToMCP()
go cb.broadcastLoop()
}
// connectToMCP establishes connection to the MCP core WebSocket.
func (cb *ClaudeBridge) connectToMCP() {
for {
cb.reconnectMu.Lock()
if cb.mcpConn != nil {
cb.mcpConn.Close()
}
log.Printf("Claude bridge connecting to MCP at %s", cb.mcpURL)
conn, _, err := websocket.DefaultDialer.Dial(cb.mcpURL, nil)
if err != nil {
log.Printf("Claude bridge failed to connect to MCP: %v", err)
cb.reconnectMu.Unlock()
time.Sleep(5 * time.Second)
continue
}
cb.mcpConn = conn
cb.reconnectMu.Unlock()
log.Printf("Claude bridge connected to MCP")
// Read messages from MCP and broadcast to clients
for {
_, message, err := conn.ReadMessage()
if err != nil {
log.Printf("Claude bridge MCP read error: %v", err)
break
}
select {
case cb.broadcast <- message:
default:
log.Printf("Claude bridge: broadcast channel full, dropping message")
}
}
// Connection lost, retry
time.Sleep(2 * time.Second)
}
}
// broadcastLoop sends messages from MCP to all connected clients.
func (cb *ClaudeBridge) broadcastLoop() {
for message := range cb.broadcast {
var failedClients []*websocket.Conn
cb.clientsMu.RLock()
for client := range cb.clients {
err := client.WriteMessage(websocket.TextMessage, message)
if err != nil {
log.Printf("Claude bridge client write error: %v", err)
failedClients = append(failedClients, client)
}
}
cb.clientsMu.RUnlock()
if len(failedClients) > 0 {
cb.clientsMu.Lock()
for _, client := range failedClients {
delete(cb.clients, client)
client.Close()
}
cb.clientsMu.Unlock()
}
}
}
// HandleWebSocket handles WebSocket connections from GUI clients.
func (cb *ClaudeBridge) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
conn, err := wsUpgrader.Upgrade(w, r, nil)
if err != nil {
log.Printf("Claude bridge upgrade error: %v", err)
return
}
// Send connected message before registering to avoid concurrent writes
connMsg, _ := json.Marshal(map[string]any{
"type": "system",
"data": "Connected to Claude bridge",
"timestamp": time.Now(),
})
if err := conn.WriteMessage(websocket.TextMessage, connMsg); err != nil {
log.Printf("Claude bridge initial write error: %v", err)
conn.Close()
return
}
cb.clientsMu.Lock()
cb.clients[conn] = true
cb.clientsMu.Unlock()
defer func() {
cb.clientsMu.Lock()
delete(cb.clients, conn)
cb.clientsMu.Unlock()
conn.Close()
}()
// Read messages from client and forward to MCP
for {
_, message, err := conn.ReadMessage()
if err != nil {
break
}
// Parse the message to check type
var msg map[string]any
if err := json.Unmarshal(message, &msg); err != nil {
continue
}
// Forward claude_message to MCP
if msgType, ok := msg["type"].(string); ok && msgType == "claude_message" {
cb.sendToMCP(message)
}
}
}
// sendToMCP sends a message to the MCP WebSocket.
func (cb *ClaudeBridge) sendToMCP(message []byte) {
cb.reconnectMu.Lock()
defer cb.reconnectMu.Unlock()
if cb.mcpConn == nil {
log.Printf("Claude bridge: MCP not connected")
return
}
err := cb.mcpConn.WriteMessage(websocket.TextMessage, message)
if err != nil {
log.Printf("Claude bridge MCP write error: %v", err)
}
}