diff --git a/http_auth_test.go b/http_auth_test.go index 5461792..3c502d7 100644 --- a/http_auth_test.go +++ b/http_auth_test.go @@ -88,6 +88,18 @@ func TestProxy_allowHTTP_Ugly(t *testing.T) { } } +func TestProxy_allowHTTP_NilConfig_Ugly(t *testing.T) { + p := &Proxy{} + + status, ok := p.allowMonitoringRequest(&http.Request{Method: http.MethodGet}) + if ok { + t.Fatal("expected nil config request to be rejected") + } + if status != http.StatusServiceUnavailable { + t.Fatalf("expected status %d, got %d", http.StatusServiceUnavailable, status) + } +} + func TestProxy_startHTTP_Good(t *testing.T) { p := &Proxy{ config: &Config{ @@ -106,6 +118,14 @@ func TestProxy_startHTTP_Good(t *testing.T) { p.Stop() } +func TestProxy_startHTTP_NilConfig_Bad(t *testing.T) { + p := &Proxy{} + + if ok := p.startMonitoringServer(); ok { + t.Fatal("expected nil config to skip HTTP server start") + } +} + func TestProxy_startHTTP_Bad(t *testing.T) { listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { diff --git a/state_impl.go b/state_impl.go index ddda890..d1b5a99 100644 --- a/state_impl.go +++ b/state_impl.go @@ -192,6 +192,9 @@ func (p *Proxy) Start() { if p == nil { return } + if p.config == nil { + return + } for _, bind := range p.config.Bind { var tlsCfg *tls.Config if bind.TLS { @@ -590,8 +593,8 @@ func parseTLSVersion(value string) uint16 { } func (p *Proxy) startMonitoringServer() bool { - if p == nil || !p.config.HTTP.Enabled { - return true + if p == nil || p.config == nil || !p.config.HTTP.Enabled { + return false } mux := http.NewServeMux() mux.HandleFunc("/1/summary", func(w http.ResponseWriter, r *http.Request) { @@ -640,10 +643,10 @@ func (p *Proxy) startMonitoringServer() bool { } func (p *Proxy) allowMonitoringRequest(r *http.Request) (int, bool) { - if p == nil { + if p == nil || p.config == nil { return http.StatusServiceUnavailable, false } - if p.config != nil && p.config.HTTP.Restricted && r.Method != http.MethodGet { + if p.config.HTTP.Restricted && r.Method != http.MethodGet { return http.StatusMethodNotAllowed, false } if token := p.config.HTTP.AccessToken; token != "" {