From 29f4c23977b7b759e8682eaa54f804ea67efdf94 Mon Sep 17 00:00:00 2001 From: Virgil Date: Thu, 2 Apr 2026 06:04:06 +0000 Subject: [PATCH] fix(api): preserve streaming response passthrough Co-Authored-By: Virgil --- response_meta.go | 73 +++++++++++++++++++++++++++++++++++++++++++++-- sse_test.go | 61 +++++++++++++++++++++++++++++++++++++++ websocket_test.go | 47 ++++++++++++++++++++++++++++++ 3 files changed, 178 insertions(+), 3 deletions(-) diff --git a/response_meta.go b/response_meta.go index 0ca66bd..8438a7c 100644 --- a/response_meta.go +++ b/response_meta.go @@ -25,6 +25,8 @@ type responseMetaRecorder struct { body bytes.Buffer status int wroteHeader bool + committed bool + passthrough bool } func newResponseMetaRecorder(w gin.ResponseWriter) *responseMetaRecorder { @@ -45,15 +47,32 @@ func (w *responseMetaRecorder) Header() http.Header { } func (w *responseMetaRecorder) WriteHeader(code int) { + if w.passthrough { + w.status = code + w.wroteHeader = true + w.ResponseWriter.WriteHeader(code) + return + } w.status = code w.wroteHeader = true } func (w *responseMetaRecorder) WriteHeaderNow() { + if w.passthrough { + w.wroteHeader = true + w.ResponseWriter.WriteHeaderNow() + return + } w.wroteHeader = true } func (w *responseMetaRecorder) Write(data []byte) (int, error) { + if w.passthrough { + if !w.wroteHeader { + w.WriteHeader(http.StatusOK) + } + return w.ResponseWriter.Write(data) + } if !w.wroteHeader { w.WriteHeader(http.StatusOK) } @@ -61,6 +80,12 @@ func (w *responseMetaRecorder) Write(data []byte) (int, error) { } func (w *responseMetaRecorder) WriteString(s string) (int, error) { + if w.passthrough { + if !w.wroteHeader { + w.WriteHeader(http.StatusOK) + } + return w.ResponseWriter.WriteString(s) + } if !w.wroteHeader { w.WriteHeader(http.StatusOK) } @@ -68,6 +93,23 @@ func (w *responseMetaRecorder) WriteString(s string) (int, error) { } func (w *responseMetaRecorder) Flush() { + if w.passthrough { + if f, ok := w.ResponseWriter.(http.Flusher); ok { + f.Flush() + } + return + } + + if !w.wroteHeader { + w.WriteHeader(http.StatusOK) + } + + w.commit(true) + w.passthrough = true + + if f, ok := w.ResponseWriter.(http.Flusher); ok { + f.Flush() + } } func (w *responseMetaRecorder) Status() int { @@ -87,10 +129,27 @@ func (w *responseMetaRecorder) Written() bool { } func (w *responseMetaRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if w.passthrough { + if h, ok := w.ResponseWriter.(http.Hijacker); ok { + return h.Hijack() + } + return nil, nil, io.ErrClosedPipe + } + + w.wroteHeader = true + w.passthrough = true + + if h, ok := w.ResponseWriter.(http.Hijacker); ok { + return h.Hijack() + } return nil, nil, io.ErrClosedPipe } -func (w *responseMetaRecorder) commit() { +func (w *responseMetaRecorder) commit(writeBody bool) { + if w.committed { + return + } + for k := range w.ResponseWriter.Header() { w.ResponseWriter.Header().Del(k) } @@ -102,7 +161,11 @@ func (w *responseMetaRecorder) commit() { } w.ResponseWriter.WriteHeader(w.Status()) - _, _ = w.ResponseWriter.Write(w.body.Bytes()) + if writeBody { + _, _ = w.ResponseWriter.Write(w.body.Bytes()) + w.body.Reset() + } + w.committed = true } // responseMetaMiddleware injects request metadata into JSON envelope @@ -118,6 +181,10 @@ func responseMetaMiddleware() gin.HandlerFunc { c.Next() + if recorder.passthrough { + return + } + body := recorder.body.Bytes() if meta := GetRequestMeta(c); meta != nil && shouldAttachResponseMeta(recorder.Header().Get("Content-Type"), body) { if refreshed := refreshResponseMetaBody(body, meta); refreshed != nil { @@ -128,7 +195,7 @@ func responseMetaMiddleware() gin.HandlerFunc { recorder.body.Reset() _, _ = recorder.body.Write(body) recorder.Header().Set("Content-Length", strconv.Itoa(len(body))) - recorder.commit() + recorder.commit(true) } } diff --git a/sse_test.go b/sse_test.go index 7276c6d..cfa950c 100644 --- a/sse_test.go +++ b/sse_test.go @@ -242,6 +242,67 @@ func TestWithSSE_Good_CombinesWithOtherMiddleware(t *testing.T) { } } +func TestWithSSE_Good_WithResponseMetaStillStreamsEvents(t *testing.T) { + gin.SetMode(gin.TestMode) + + broker := api.NewSSEBroker() + e, err := api.New( + api.WithRequestID(), + api.WithResponseMeta(), + api.WithSSE(broker), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + srv := httptest.NewServer(e.Handler()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/events") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if ct := resp.Header.Get("Content-Type"); !strings.HasPrefix(ct, "text/event-stream") { + t.Fatalf("expected Content-Type starting with text/event-stream, got %q", ct) + } + if reqID := resp.Header.Get("X-Request-ID"); reqID == "" { + t.Fatal("expected X-Request-ID header from RequestID middleware") + } + + waitForClients(t, broker, 1) + + broker.Publish("test", "greeting", map[string]string{"msg": "hello"}) + + scanner := bufio.NewScanner(resp.Body) + var eventLine string + + deadline := time.After(3 * time.Second) + done := make(chan struct{}) + + go func() { + defer close(done) + for scanner.Scan() { + line := scanner.Text() + if after, ok := strings.CutPrefix(line, "event: "); ok { + eventLine = after + return + } + } + }() + + select { + case <-done: + case <-deadline: + t.Fatal("timed out waiting for SSE event with response meta enabled") + } + + if eventLine != "greeting" { + t.Fatalf("expected event=%q, got %q", "greeting", eventLine) + } +} + func TestWithSSE_Good_MultipleClients(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/websocket_test.go b/websocket_test.go index d287364..5d950b4 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -118,6 +118,53 @@ func TestWSEndpoint_Good_CustomPath(t *testing.T) { } } +func TestWSEndpoint_Good_WithResponseMeta(t *testing.T) { + gin.SetMode(gin.TestMode) + + upgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + } + wsHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Logf("upgrade error: %v", err) + return + } + defer conn.Close() + _ = conn.WriteMessage(websocket.TextMessage, []byte("meta")) + }) + + e, err := api.New( + api.WithRequestID(), + api.WithResponseMeta(), + api.WithWSHandler(wsHandler), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + srv := httptest.NewServer(e.Handler()) + defer srv.Close() + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "/ws" + conn, resp, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + if resp != nil { + t.Fatalf("failed to dial WebSocket: %v (status=%d)", err, resp.StatusCode) + } + t.Fatalf("failed to dial WebSocket: %v", err) + } + defer conn.Close() + + _, msg, err := conn.ReadMessage() + if err != nil { + t.Fatalf("failed to read message: %v", err) + } + if string(msg) != "meta" { + t.Fatalf("expected message=%q, got %q", "meta", string(msg)) + } +} + func TestNoWSHandler_Good(t *testing.T) { gin.SetMode(gin.TestMode)