diff --git a/http_auth_test.go b/http_auth_test.go new file mode 100644 index 0000000..8fed0fe --- /dev/null +++ b/http_auth_test.go @@ -0,0 +1,71 @@ +package proxy + +import ( + "net/http" + "testing" +) + +func TestProxy_allowHTTP_Good(t *testing.T) { + p := &Proxy{ + config: &Config{ + HTTP: HTTPConfig{ + Restricted: true, + AccessToken: "secret", + }, + }, + } + + status, ok := p.allowHTTP(&http.Request{ + Method: http.MethodGet, + Header: http.Header{ + "Authorization": []string{"Bearer secret"}, + }, + }) + if !ok { + t.Fatalf("expected authorised request to pass, got status %d", status) + } + if status != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, status) + } +} + +func TestProxy_allowHTTP_Bad(t *testing.T) { + p := &Proxy{ + config: &Config{ + HTTP: HTTPConfig{ + Restricted: true, + }, + }, + } + + status, ok := p.allowHTTP(&http.Request{Method: http.MethodPost}) + if ok { + t.Fatal("expected non-GET request to be rejected") + } + if status != http.StatusMethodNotAllowed { + t.Fatalf("expected status %d, got %d", http.StatusMethodNotAllowed, status) + } +} + +func TestProxy_allowHTTP_Ugly(t *testing.T) { + p := &Proxy{ + config: &Config{ + HTTP: HTTPConfig{ + AccessToken: "secret", + }, + }, + } + + status, ok := p.allowHTTP(&http.Request{ + Method: http.MethodGet, + Header: http.Header{ + "Authorization": []string{"Bearer wrong"}, + }, + }) + if ok { + t.Fatal("expected invalid token to be rejected") + } + if status != http.StatusUnauthorized { + t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, status) + } +} diff --git a/state_impl.go b/state_impl.go index 3874e73..0415538 100644 --- a/state_impl.go +++ b/state_impl.go @@ -588,22 +588,31 @@ func parseTLSVersion(value string) uint16 { func (p *Proxy) startHTTP() { mux := http.NewServeMux() mux.HandleFunc("/1/summary", func(w http.ResponseWriter, r *http.Request) { - if !p.allowHTTP(r) { - w.WriteHeader(http.StatusMethodNotAllowed) + if status, ok := p.allowHTTP(r); !ok { + if status == http.StatusUnauthorized { + w.Header().Set("WWW-Authenticate", "Bearer") + } + w.WriteHeader(status) return } p.writeJSON(w, p.summaryDocument()) }) mux.HandleFunc("/1/workers", func(w http.ResponseWriter, r *http.Request) { - if !p.allowHTTP(r) { - w.WriteHeader(http.StatusMethodNotAllowed) + if status, ok := p.allowHTTP(r); !ok { + if status == http.StatusUnauthorized { + w.Header().Set("WWW-Authenticate", "Bearer") + } + w.WriteHeader(status) return } p.writeJSON(w, p.workersDocument()) }) mux.HandleFunc("/1/miners", func(w http.ResponseWriter, r *http.Request) { - if !p.allowHTTP(r) { - w.WriteHeader(http.StatusMethodNotAllowed) + if status, ok := p.allowHTTP(r); !ok { + if status == http.StatusUnauthorized { + w.Header().Set("WWW-Authenticate", "Bearer") + } + w.WriteHeader(status) return } p.writeJSON(w, p.minersDocument()) @@ -615,20 +624,20 @@ func (p *Proxy) startHTTP() { }() } -func (p *Proxy) allowHTTP(r *http.Request) bool { +func (p *Proxy) allowHTTP(r *http.Request) (int, bool) { if p == nil { - return false + return http.StatusServiceUnavailable, false } if p.config.HTTP.Restricted && r.Method != http.MethodGet { - return false + return http.StatusMethodNotAllowed, false } if token := p.config.HTTP.AccessToken; token != "" { parts := strings.SplitN(r.Header.Get("Authorization"), " ", 2) if len(parts) != 2 || !strings.EqualFold(parts[0], "bearer") || parts[1] != token { - return false + return http.StatusUnauthorized, false } } - return true + return http.StatusOK, true } func (p *Proxy) writeJSON(w http.ResponseWriter, payload any) {