feat: Add API authentication and comprehensive code review fixes

Security:
- Add HTTP Basic/Digest authentication middleware (enable via MINING_API_AUTH env)
- Fix WebSocket origin check with proper URL parsing
- Add max limit (10000) to remote log lines request
- Improve CLI args validation with stricter patterns

Networking:
- Fix WebSocket double-close with sync.Once in PeerConnection
- Add 10s dial timeout for WebSocket connections
- Reset write deadline after failed sends
- Fix handler race in Transport.OnMessage with RWMutex
- Make EventHub.Stop() idempotent, buffer channels to prevent goroutine leaks

Code Simplification:
- Extract AtomicWriteFile helper to reduce duplication across 4 files
- Remove redundant MinerTypeRegistry, use MinerFactory instead
- Register simulated miner in MinerFactory
- Remove dead portToString() code from manager.go

Documentation:
- Add Advanced API Authentication section to FUTURE_IDEAS.md

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
snider 2025-12-31 14:07:26 +00:00
parent fa3047a314
commit c2ff474386
15 changed files with 475 additions and 213 deletions

View file

@ -139,6 +139,47 @@ deploy/
---
## Advanced API Authentication
**Priority:** Medium
**Effort:** Medium
Expand beyond basic/digest auth with more robust authentication options.
### Current Implementation
- HTTP Basic and Digest authentication (implemented)
- Enabled via environment variables: `MINING_API_AUTH`, `MINING_API_USER`, `MINING_API_PASS`
### Future Options
#### JWT Tokens
- Stateless authentication with expiring tokens
- Refresh token support
- Scoped permissions (read-only, admin, etc.)
#### API Keys
- Generate/revoke API keys from dashboard
- Per-key permissions and rate limits
- Key rotation support
#### OAuth2/OIDC Integration
- Support external identity providers (Google, GitHub, Keycloak)
- SSO for enterprise deployments
- Useful for multi-user mining farms
#### mTLS (Mutual TLS)
- Certificate-based client authentication
- Strongest security for production deployments
- No passwords to manage
### Implementation Notes
- Store credentials/keys in encrypted config file
- Add `/api/v1/auth/token` endpoint for JWT issuance
- Consider using `golang-jwt/jwt` for JWT implementation
- Add audit logging for authentication events
---
## Additional Ideas
### GPU Temperature Monitoring

248
pkg/mining/auth.go Normal file
View file

@ -0,0 +1,248 @@
package mining
import (
"crypto/md5"
"crypto/rand"
"crypto/subtle"
"encoding/hex"
"fmt"
"net/http"
"os"
"strings"
"sync"
"time"
"github.com/Snider/Mining/pkg/logging"
"github.com/gin-gonic/gin"
)
// AuthConfig holds authentication configuration
type AuthConfig struct {
// Enabled determines if authentication is required
Enabled bool
// Username for basic/digest auth
Username string
// Password for basic/digest auth
Password string
// Realm for digest auth
Realm string
// NonceExpiry is how long a nonce is valid
NonceExpiry time.Duration
}
// DefaultAuthConfig returns the default auth configuration.
// Auth is disabled by default for local development.
func DefaultAuthConfig() AuthConfig {
return AuthConfig{
Enabled: false,
Username: "",
Password: "",
Realm: "Mining API",
NonceExpiry: 5 * time.Minute,
}
}
// AuthConfigFromEnv creates auth config from environment variables.
// Set MINING_API_AUTH=true to enable, MINING_API_USER and MINING_API_PASS for credentials.
func AuthConfigFromEnv() AuthConfig {
config := DefaultAuthConfig()
if os.Getenv("MINING_API_AUTH") == "true" {
config.Enabled = true
config.Username = os.Getenv("MINING_API_USER")
config.Password = os.Getenv("MINING_API_PASS")
if config.Username == "" || config.Password == "" {
logging.Warn("API auth enabled but credentials not set", logging.Fields{
"hint": "Set MINING_API_USER and MINING_API_PASS environment variables",
})
config.Enabled = false
}
}
if realm := os.Getenv("MINING_API_REALM"); realm != "" {
config.Realm = realm
}
return config
}
// DigestAuth implements HTTP Digest Authentication middleware
type DigestAuth struct {
config AuthConfig
nonces sync.Map // map[string]time.Time for nonce expiry tracking
}
// NewDigestAuth creates a new digest auth middleware
func NewDigestAuth(config AuthConfig) *DigestAuth {
da := &DigestAuth{config: config}
// Start nonce cleanup goroutine
go da.cleanupNonces()
return da
}
// Middleware returns a Gin middleware that enforces digest authentication
func (da *DigestAuth) Middleware() gin.HandlerFunc {
return func(c *gin.Context) {
if !da.config.Enabled {
c.Next()
return
}
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
da.sendChallenge(c)
return
}
// Try digest auth first
if strings.HasPrefix(authHeader, "Digest ") {
if da.validateDigest(c, authHeader) {
c.Next()
return
}
da.sendChallenge(c)
return
}
// Fall back to basic auth
if strings.HasPrefix(authHeader, "Basic ") {
if da.validateBasic(c, authHeader) {
c.Next()
return
}
}
da.sendChallenge(c)
}
}
// sendChallenge sends a 401 response with digest auth challenge
func (da *DigestAuth) sendChallenge(c *gin.Context) {
nonce := da.generateNonce()
da.nonces.Store(nonce, time.Now())
challenge := fmt.Sprintf(
`Digest realm="%s", qop="auth", nonce="%s", opaque="%s"`,
da.config.Realm,
nonce,
da.generateOpaque(),
)
c.Header("WWW-Authenticate", challenge)
c.AbortWithStatusJSON(http.StatusUnauthorized, APIError{
Code: "AUTH_REQUIRED",
Message: "Authentication required",
Suggestion: "Provide valid credentials using Digest or Basic authentication",
})
}
// validateDigest validates a digest auth header
func (da *DigestAuth) validateDigest(c *gin.Context, authHeader string) bool {
params := parseDigestParams(authHeader[7:]) // Skip "Digest "
nonce := params["nonce"]
if nonce == "" {
return false
}
// Check nonce validity
if storedTime, ok := da.nonces.Load(nonce); ok {
if time.Since(storedTime.(time.Time)) > da.config.NonceExpiry {
da.nonces.Delete(nonce)
return false
}
} else {
return false
}
// Validate username
if params["username"] != da.config.Username {
return false
}
// Calculate expected response
ha1 := md5Hash(fmt.Sprintf("%s:%s:%s", da.config.Username, da.config.Realm, da.config.Password))
ha2 := md5Hash(fmt.Sprintf("%s:%s", c.Request.Method, params["uri"]))
var expectedResponse string
if params["qop"] == "auth" {
expectedResponse = md5Hash(fmt.Sprintf("%s:%s:%s:%s:%s:%s",
ha1, nonce, params["nc"], params["cnonce"], params["qop"], ha2))
} else {
expectedResponse = md5Hash(fmt.Sprintf("%s:%s:%s", ha1, nonce, ha2))
}
// Constant-time comparison to prevent timing attacks
return subtle.ConstantTimeCompare([]byte(expectedResponse), []byte(params["response"])) == 1
}
// validateBasic validates a basic auth header
func (da *DigestAuth) validateBasic(c *gin.Context, authHeader string) bool {
// Gin has built-in basic auth, but we do manual validation for consistency
user, pass, ok := c.Request.BasicAuth()
if !ok {
return false
}
// Constant-time comparison to prevent timing attacks
userMatch := subtle.ConstantTimeCompare([]byte(user), []byte(da.config.Username)) == 1
passMatch := subtle.ConstantTimeCompare([]byte(pass), []byte(da.config.Password)) == 1
return userMatch && passMatch
}
// generateNonce creates a cryptographically random nonce
func (da *DigestAuth) generateNonce() string {
b := make([]byte, 16)
rand.Read(b)
return hex.EncodeToString(b)
}
// generateOpaque creates an opaque value
func (da *DigestAuth) generateOpaque() string {
return md5Hash(da.config.Realm)
}
// cleanupNonces removes expired nonces periodically
func (da *DigestAuth) cleanupNonces() {
ticker := time.NewTicker(da.config.NonceExpiry)
defer ticker.Stop()
for range ticker.C {
now := time.Now()
da.nonces.Range(func(key, value interface{}) bool {
if now.Sub(value.(time.Time)) > da.config.NonceExpiry {
da.nonces.Delete(key)
}
return true
})
}
}
// parseDigestParams parses the parameters from a digest auth header
func parseDigestParams(header string) map[string]string {
params := make(map[string]string)
parts := strings.Split(header, ",")
for _, part := range parts {
part = strings.TrimSpace(part)
idx := strings.Index(part, "=")
if idx < 0 {
continue
}
key := strings.TrimSpace(part[:idx])
value := strings.TrimSpace(part[idx+1:])
// Remove quotes
value = strings.Trim(value, `"`)
params[key] = value
}
return params
}
// md5Hash returns the MD5 hash of a string as a hex string
func md5Hash(s string) string {
h := md5.Sum([]byte(s))
return hex.EncodeToString(h[:])
}

View file

@ -103,45 +103,7 @@ func SaveMinersConfig(cfg *MinersConfig) error {
return fmt.Errorf("failed to marshal miners config: %w", err)
}
// Atomic write: write to temp file, then rename
tmpFile, err := os.CreateTemp(dir, "miners-config-*.tmp")
if err != nil {
return fmt.Errorf("failed to create temp file: %w", err)
}
tmpPath := tmpFile.Name()
// Clean up temp file on error
success := false
defer func() {
if !success {
os.Remove(tmpPath)
}
}()
if _, err := tmpFile.Write(data); err != nil {
tmpFile.Close()
return fmt.Errorf("failed to write temp file: %w", err)
}
if err := tmpFile.Sync(); err != nil {
tmpFile.Close()
return fmt.Errorf("failed to sync temp file: %w", err)
}
if err := tmpFile.Close(); err != nil {
return fmt.Errorf("failed to close temp file: %w", err)
}
if err := os.Chmod(tmpPath, 0600); err != nil {
return fmt.Errorf("failed to set temp file permissions: %w", err)
}
if err := os.Rename(tmpPath, configPath); err != nil {
return fmt.Errorf("failed to rename temp file: %w", err)
}
success = true
return nil
return AtomicWriteFile(configPath, data, 0600)
}
// UpdateMinersConfig atomically loads, modifies, and saves the miners config.
@ -192,41 +154,5 @@ func UpdateMinersConfig(fn func(*MinersConfig) error) error {
return fmt.Errorf("failed to marshal miners config: %w", err)
}
tmpFile, err := os.CreateTemp(dir, "miners-config-*.tmp")
if err != nil {
return fmt.Errorf("failed to create temp file: %w", err)
}
tmpPath := tmpFile.Name()
success := false
defer func() {
if !success {
os.Remove(tmpPath)
}
}()
if _, err := tmpFile.Write(newData); err != nil {
tmpFile.Close()
return fmt.Errorf("failed to write temp file: %w", err)
}
if err := tmpFile.Sync(); err != nil {
tmpFile.Close()
return fmt.Errorf("failed to sync temp file: %w", err)
}
if err := tmpFile.Close(); err != nil {
return fmt.Errorf("failed to close temp file: %w", err)
}
if err := os.Chmod(tmpPath, 0600); err != nil {
return fmt.Errorf("failed to set temp file permissions: %w", err)
}
if err := os.Rename(tmpPath, configPath); err != nil {
return fmt.Errorf("failed to rename temp file: %w", err)
}
success = true
return nil
return AtomicWriteFile(configPath, newData, 0600)
}

View file

@ -86,6 +86,9 @@ type EventHub struct {
// Stop signal
stop chan struct{}
// Ensure Stop() is called only once
stopOnce sync.Once
// Connection limits
maxConnections int
@ -109,8 +112,8 @@ func NewEventHubWithOptions(maxConnections int) *EventHub {
return &EventHub{
clients: make(map[*wsClient]bool),
broadcast: make(chan Event, 256),
register: make(chan *wsClient),
unregister: make(chan *wsClient),
register: make(chan *wsClient, 16),
unregister: make(chan *wsClient, 16), // Buffered to prevent goroutine leaks on shutdown
stop: make(chan struct{}),
maxConnections: maxConnections,
}
@ -235,9 +238,11 @@ func (h *EventHub) shouldSendToClient(client *wsClient, event Event) bool {
return client.miners[minerName]
}
// Stop stops the EventHub
// Stop stops the EventHub (safe to call multiple times)
func (h *EventHub) Stop() {
close(h.stop)
h.stopOnce.Do(func() {
close(h.stop)
})
}
// SetStateProvider sets the function that provides current state for new clients

57
pkg/mining/file_utils.go Normal file
View file

@ -0,0 +1,57 @@
package mining
import (
"fmt"
"os"
"path/filepath"
)
// AtomicWriteFile writes data to a file atomically by writing to a temp file
// first, syncing to disk, then renaming to the target path. This prevents
// corruption if the process is interrupted during write.
func AtomicWriteFile(path string, data []byte, perm os.FileMode) error {
dir := filepath.Dir(path)
// Create temp file in the same directory for atomic rename
tmpFile, err := os.CreateTemp(dir, ".tmp-*")
if err != nil {
return fmt.Errorf("failed to create temp file: %w", err)
}
tmpPath := tmpFile.Name()
// Clean up temp file on error
success := false
defer func() {
if !success {
os.Remove(tmpPath)
}
}()
if _, err := tmpFile.Write(data); err != nil {
tmpFile.Close()
return fmt.Errorf("failed to write temp file: %w", err)
}
// Sync to ensure data is flushed to disk before rename
if err := tmpFile.Sync(); err != nil {
tmpFile.Close()
return fmt.Errorf("failed to sync temp file: %w", err)
}
if err := tmpFile.Close(); err != nil {
return fmt.Errorf("failed to close temp file: %w", err)
}
// Set permissions before rename
if err := os.Chmod(tmpPath, perm); err != nil {
return fmt.Errorf("failed to set file permissions: %w", err)
}
// Atomic rename (on POSIX systems)
if err := os.Rename(tmpPath, path); err != nil {
return fmt.Errorf("failed to rename temp file: %w", err)
}
success = true
return nil
}

View file

@ -5,7 +5,6 @@ import (
"fmt"
"net"
"regexp"
"strconv"
"strings"
"sync"
"time"
@ -718,8 +717,3 @@ func (m *Manager) GetAllMinerHistoricalStats() ([]database.HashrateStats, error)
func (m *Manager) IsDatabaseEnabled() bool {
return m.dbEnabled
}
// Helper to convert port to string for net.JoinHostPort
func portToString(port int) string {
return strconv.Itoa(port)
}

View file

@ -37,6 +37,16 @@ func (f *MinerFactory) registerDefaults() {
// TT-Miner (GPU Kawpow, etc.)
f.Register("tt-miner", func() Miner { return NewTTMiner() })
f.RegisterAlias("ttminer", "tt-miner")
// Simulated miner for testing and development
f.Register(MinerTypeSimulated, func() Miner {
return NewSimulatedMiner(SimulatedMinerConfig{
Name: "simulated-miner",
Algorithm: "rx/0",
BaseHashrate: 1000,
Variance: 0.1,
})
})
}
// Register adds a miner constructor to the factory

View file

@ -401,16 +401,20 @@ func (ns *NodeService) handleRemoteStop(c *gin.Context) {
// @Produce json
// @Param peerId path string true "Peer ID"
// @Param miner path string true "Miner Name"
// @Param lines query int false "Number of lines" default(100)
// @Param lines query int false "Number of lines (max 10000)" default(100)
// @Success 200 {array} string
// @Router /remote/{peerId}/logs/{miner} [get]
func (ns *NodeService) handleRemoteLogs(c *gin.Context) {
peerID := c.Param("peerId")
minerName := c.Param("miner")
lines := 100
const maxLines = 10000 // Prevent resource exhaustion
if l := c.Query("lines"); l != "" {
if parsed, err := strconv.Atoi(l); err == nil && parsed > 0 {
lines = parsed
if lines > maxLines {
lines = maxLines
}
}
}

View file

@ -79,49 +79,7 @@ func (pm *ProfileManager) saveProfiles() error {
return err
}
// Atomic write: write to temp file in same directory, then rename
dir := filepath.Dir(pm.configPath)
tmpFile, err := os.CreateTemp(dir, "profiles-*.tmp")
if err != nil {
return fmt.Errorf("failed to create temp file: %w", err)
}
tmpPath := tmpFile.Name()
// Clean up temp file on any error
success := false
defer func() {
if !success {
os.Remove(tmpPath)
}
}()
if _, err := tmpFile.Write(data); err != nil {
tmpFile.Close()
return fmt.Errorf("failed to write temp file: %w", err)
}
// Sync to ensure data is flushed to disk before rename
if err := tmpFile.Sync(); err != nil {
tmpFile.Close()
return fmt.Errorf("failed to sync temp file: %w", err)
}
if err := tmpFile.Close(); err != nil {
return fmt.Errorf("failed to close temp file: %w", err)
}
// Set permissions before rename
if err := os.Chmod(tmpPath, 0600); err != nil {
return fmt.Errorf("failed to set temp file permissions: %w", err)
}
// Atomic rename (on POSIX systems)
if err := os.Rename(tmpPath, pm.configPath); err != nil {
return fmt.Errorf("failed to rename temp file: %w", err)
}
success = true
return nil
return AtomicWriteFile(pm.configPath, data, 0600)
}
// CreateProfile adds a new profile and saves it.

View file

@ -96,45 +96,7 @@ func (r *FileRepository[T]) saveUnlocked(data T) error {
return fmt.Errorf("failed to marshal data: %w", err)
}
// Atomic write: write to temp file, sync, then rename
tmpFile, err := os.CreateTemp(dir, "repo-*.tmp")
if err != nil {
return fmt.Errorf("failed to create temp file: %w", err)
}
tmpPath := tmpFile.Name()
// Clean up temp file on error
success := false
defer func() {
if !success {
os.Remove(tmpPath)
}
}()
if _, err := tmpFile.Write(jsonData); err != nil {
tmpFile.Close()
return fmt.Errorf("failed to write temp file: %w", err)
}
if err := tmpFile.Sync(); err != nil {
tmpFile.Close()
return fmt.Errorf("failed to sync temp file: %w", err)
}
if err := tmpFile.Close(); err != nil {
return fmt.Errorf("failed to close temp file: %w", err)
}
if err := os.Chmod(tmpPath, 0600); err != nil {
return fmt.Errorf("failed to set file permissions: %w", err)
}
if err := os.Rename(tmpPath, r.path); err != nil {
return fmt.Errorf("failed to rename temp file: %w", err)
}
success = true
return nil
return AtomicWriteFile(r.path, jsonData, 0600)
}
// Update atomically loads, modifies, and saves data.

View file

@ -7,6 +7,7 @@ import (
"fmt"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"runtime"
@ -40,6 +41,7 @@ type Service struct {
APIBasePath string
SwaggerUIPath string
rateLimiter *RateLimiter
auth *DigestAuth
}
// APIError represents a structured error response for the API
@ -146,15 +148,20 @@ var wsUpgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
// Allow connections from localhost origins
// Allow connections from localhost origins only
origin := r.Header.Get("Origin")
if origin == "" {
return true
return true // No origin header (non-browser clients)
}
// Allow localhost with any port
return strings.Contains(origin, "localhost") ||
strings.Contains(origin, "127.0.0.1") ||
strings.Contains(origin, "wails.localhost")
// Parse the origin URL properly to prevent bypass attacks
u, err := url.Parse(origin)
if err != nil {
return false
}
host := u.Hostname()
// Only allow exact localhost matches
return host == "localhost" || host == "127.0.0.1" || host == "::1" ||
host == "wails.localhost"
},
}
@ -218,6 +225,14 @@ func NewService(manager ManagerInterface, listenAddr string, displayAddr string,
}
})
// Initialize authentication from environment
authConfig := AuthConfigFromEnv()
var auth *DigestAuth
if authConfig.Enabled {
auth = NewDigestAuth(authConfig)
logging.Info("API authentication enabled", logging.Fields{"realm": authConfig.Realm})
}
return &Service{
Manager: manager,
ProfileManager: profileManager,
@ -234,6 +249,7 @@ func NewService(manager ManagerInterface, listenAddr string, displayAddr string,
SwaggerInstanceName: instanceName,
APIBasePath: apiBasePath,
SwaggerUIPath: swaggerUIPath,
auth: auth,
}, nil
}
@ -351,6 +367,12 @@ func (s *Service) ServiceStartup(ctx context.Context) error {
// manually after InitRouter for embedding in other applications.
func (s *Service) SetupRoutes() {
apiGroup := s.Router.Group(s.APIBasePath)
// Apply authentication middleware if enabled
if s.auth != nil {
apiGroup.Use(s.auth.Middleware())
}
{
apiGroup.GET("/info", s.handleGetInfo)
apiGroup.POST("/doctor", s.handleDoctor)

View file

@ -56,16 +56,3 @@ func FetchJSONStats[T any](ctx context.Context, config HTTPStatsConfig, target *
return nil
}
// MinerTypeRegistry provides a central registry of known miner types.
// This can be used for validation and discovery of available miners.
var MinerTypeRegistry = map[string]string{
MinerTypeXMRig: "XMRig - CPU/GPU miner for RandomX, KawPow, CryptoNight",
MinerTypeTTMiner: "TT-Miner - NVIDIA GPU miner for Ethash, KawPow, ProgPow",
MinerTypeSimulated: "Simulated - Mock miner for testing and development",
}
// IsKnownMinerType returns true if the given type is a registered miner type.
func IsKnownMinerType(minerType string) bool {
_, exists := MinerTypeRegistry[minerType]
return exists
}

View file

@ -87,28 +87,27 @@ func TestFetchJSONStats(t *testing.T) {
func TestMinerTypeRegistry(t *testing.T) {
t.Run("KnownTypes", func(t *testing.T) {
if !IsKnownMinerType(MinerTypeXMRig) {
if !IsMinerSupported(MinerTypeXMRig) {
t.Error("xmrig should be a known miner type")
}
if !IsKnownMinerType(MinerTypeTTMiner) {
if !IsMinerSupported(MinerTypeTTMiner) {
t.Error("tt-miner should be a known miner type")
}
if !IsKnownMinerType(MinerTypeSimulated) {
if !IsMinerSupported(MinerTypeSimulated) {
t.Error("simulated should be a known miner type")
}
})
t.Run("UnknownType", func(t *testing.T) {
if IsKnownMinerType("unknown-miner") {
if IsMinerSupported("unknown-miner") {
t.Error("unknown-miner should not be a known miner type")
}
})
t.Run("RegistryHasDescriptions", func(t *testing.T) {
for minerType, description := range MinerTypeRegistry {
if description == "" {
t.Errorf("Miner type %s has empty description", minerType)
}
t.Run("ListMinerTypes", func(t *testing.T) {
types := ListMinerTypes()
if len(types) == 0 {
t.Error("ListMinerTypes should return registered types")
}
})
}

View file

@ -173,22 +173,56 @@ func addTTMinerCliArgs(config *Config, args *[]string) {
}
}
// isValidCLIArg validates CLI arguments to prevent injection or dangerous patterns
// isValidCLIArg validates CLI arguments to prevent injection or dangerous patterns.
// Uses a combination of allowlist patterns and blocklist for security.
func isValidCLIArg(arg string) bool {
// Empty or whitespace-only args are invalid
if strings.TrimSpace(arg) == "" {
return false
}
// Must start with dash (standard CLI argument format)
// This is an allowlist approach - only accept valid argument patterns
if !strings.HasPrefix(arg, "-") {
// Allow values for flags (e.g., the "3" in "-i 3")
// Values must not contain shell metacharacters
return isValidArgValue(arg)
}
// Block shell metacharacters and dangerous patterns
dangerousPatterns := []string{";", "|", "&", "`", "$", "(", ")", "{", "}", "<", ">", "\n", "\r"}
if !isValidArgValue(arg) {
return false
}
// Block arguments that could override security-related settings
blockedPrefixes := []string{
"--api-access-token", "--api-worker-id", // TT-Miner API settings
"--config", // Could load arbitrary config
"--log-file", // Could write to arbitrary locations
"--coin-file", // Could load arbitrary coin configs
"-o", "--out", // Output redirection
}
lowerArg := strings.ToLower(arg)
for _, blocked := range blockedPrefixes {
if lowerArg == blocked || strings.HasPrefix(lowerArg, blocked+"=") {
return false
}
}
return true
}
// isValidArgValue checks if a value contains dangerous patterns
func isValidArgValue(arg string) bool {
// Block shell metacharacters and command injection patterns
dangerousPatterns := []string{
";", "|", "&", "`", "$", "(", ")", "{", "}",
"<", ">", "\n", "\r", "\\", "'", "\"", "!",
}
for _, p := range dangerousPatterns {
if strings.Contains(arg, p) {
return false
}
}
// Block arguments that could override security-related settings
blockedArgs := []string{"--api-access-token", "--api-worker-id"}
lowerArg := strings.ToLower(arg)
for _, blocked := range blockedArgs {
if strings.HasPrefix(lowerArg, blocked) {
return false
}
}
return true
}

View file

@ -63,6 +63,7 @@ type PeerConnection struct {
LastActivity time.Time
writeMu sync.Mutex // Serialize WebSocket writes
transport *Transport
closeOnce sync.Once // Ensure Close() is only called once
}
// NewTransport creates a new WebSocket transport.
@ -150,7 +151,10 @@ func (t *Transport) Stop() error {
}
// OnMessage sets the handler for incoming messages.
// Must be called before Start() to avoid races.
func (t *Transport) OnMessage(handler MessageHandler) {
t.mu.Lock()
defer t.mu.Unlock()
t.handler = handler
}
@ -163,8 +167,11 @@ func (t *Transport) Connect(peer *Peer) (*PeerConnection, error) {
}
u := url.URL{Scheme: scheme, Host: peer.Address, Path: t.config.WSPath}
// Dial the peer
conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
// Dial the peer with timeout to prevent hanging on unresponsive peers
dialer := websocket.Dialer{
HandshakeTimeout: 10 * time.Second,
}
conn, _, err := dialer.Dial(u.String(), nil)
if err != nil {
return nil, fmt.Errorf("failed to connect to peer: %w", err)
}
@ -485,9 +492,12 @@ func (t *Transport) readLoop(pc *PeerConnection) {
logging.Debug("received message from peer", logging.Fields{"type": msg.Type, "peer_id": pc.Peer.ID, "reply_to": msg.ReplyTo})
// Dispatch to handler
if t.handler != nil {
t.handler(pc, msg)
// Dispatch to handler (read handler under lock to avoid race)
t.mu.RLock()
handler := t.handler
t.mu.RUnlock()
if handler != nil {
handler(pc, msg)
}
}
}
@ -552,13 +562,18 @@ func (pc *PeerConnection) Send(msg *Message) error {
if err := pc.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second)); err != nil {
return fmt.Errorf("failed to set write deadline: %w", err)
}
defer pc.Conn.SetWriteDeadline(time.Time{}) // Reset deadline after send
return pc.Conn.WriteMessage(websocket.BinaryMessage, data)
}
// Close closes the connection.
func (pc *PeerConnection) Close() error {
return pc.Conn.Close()
var err error
pc.closeOnce.Do(func() {
err = pc.Conn.Close()
})
return err
}
// encryptMessage encrypts a message using SMSG with the shared secret.