From 0bb5ce827b1cebb7b4be702e861e73c89e321c21 Mon Sep 17 00:00:00 2001 From: Virgil Date: Sat, 4 Apr 2026 23:07:43 +0000 Subject: [PATCH] fix(proxy): fail fast on HTTP bind errors Co-Authored-By: Virgil --- http_auth_test.go | 52 +++++++++++++++++++++++++++++++++++++++++++++++ state_impl.go | 20 +++++++++++++++--- 2 files changed, 69 insertions(+), 3 deletions(-) diff --git a/http_auth_test.go b/http_auth_test.go index 8fed0fe..cf23b59 100644 --- a/http_auth_test.go +++ b/http_auth_test.go @@ -1,7 +1,9 @@ package proxy import ( + "net" "net/http" + "strconv" "testing" ) @@ -69,3 +71,53 @@ func TestProxy_allowHTTP_Ugly(t *testing.T) { t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, status) } } + +func TestProxy_startHTTP_Good(t *testing.T) { + p := &Proxy{ + config: &Config{ + HTTP: HTTPConfig{ + Enabled: true, + Host: "127.0.0.1", + Port: 0, + }, + }, + done: make(chan struct{}), + } + + if ok := p.startHTTP(); !ok { + t.Fatal("expected HTTP server to start on a free port") + } + p.Stop() +} + +func TestProxy_startHTTP_Bad(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen on ephemeral port: %v", err) + } + defer listener.Close() + + host, port, err := net.SplitHostPort(listener.Addr().String()) + if err != nil { + t.Fatalf("split listener addr: %v", err) + } + portNum, err := strconv.Atoi(port) + if err != nil { + t.Fatalf("parse listener port: %v", err) + } + + p := &Proxy{ + config: &Config{ + HTTP: HTTPConfig{ + Enabled: true, + Host: host, + Port: uint16(portNum), + }, + }, + done: make(chan struct{}), + } + + if ok := p.startHTTP(); ok { + t.Fatal("expected HTTP server start to fail when the port is already in use") + } +} diff --git a/state_impl.go b/state_impl.go index 0415538..92cae92 100644 --- a/state_impl.go +++ b/state_impl.go @@ -220,7 +220,10 @@ func (p *Proxy) Start() { p.watcher.Start() } if p.config.HTTP.Enabled { - p.startHTTP() + if !p.startHTTP() { + p.Stop() + return + } } p.ticker = time.NewTicker(time.Second) go func() { @@ -585,7 +588,10 @@ func parseTLSVersion(value string) uint16 { } } -func (p *Proxy) startHTTP() { +func (p *Proxy) startHTTP() bool { + if p == nil || !p.config.HTTP.Enabled { + return true + } mux := http.NewServeMux() mux.HandleFunc("/1/summary", func(w http.ResponseWriter, r *http.Request) { if status, ok := p.allowHTTP(r); !ok { @@ -618,10 +624,18 @@ func (p *Proxy) startHTTP() { p.writeJSON(w, p.minersDocument()) }) addr := net.JoinHostPort(p.config.HTTP.Host, strconv.Itoa(int(p.config.HTTP.Port))) + listener, err := net.Listen("tcp", addr) + if err != nil { + return false + } p.httpServer = &http.Server{Addr: addr, Handler: mux} go func() { - _ = p.httpServer.ListenAndServe() + err := p.httpServer.Serve(listener) + if err != nil && !errors.Is(err, http.ErrServerClosed) { + p.Stop() + } }() + return true } func (p *Proxy) allowHTTP(r *http.Request) (int, bool) {