fix(api): preserve streaming response passthrough

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Virgil 2026-04-02 06:04:06 +00:00
parent d7ef3610f7
commit 29f4c23977
3 changed files with 178 additions and 3 deletions

View file

@ -25,6 +25,8 @@ type responseMetaRecorder struct {
body bytes.Buffer
status int
wroteHeader bool
committed bool
passthrough bool
}
func newResponseMetaRecorder(w gin.ResponseWriter) *responseMetaRecorder {
@ -45,15 +47,32 @@ func (w *responseMetaRecorder) Header() http.Header {
}
func (w *responseMetaRecorder) WriteHeader(code int) {
if w.passthrough {
w.status = code
w.wroteHeader = true
w.ResponseWriter.WriteHeader(code)
return
}
w.status = code
w.wroteHeader = true
}
func (w *responseMetaRecorder) WriteHeaderNow() {
if w.passthrough {
w.wroteHeader = true
w.ResponseWriter.WriteHeaderNow()
return
}
w.wroteHeader = true
}
func (w *responseMetaRecorder) Write(data []byte) (int, error) {
if w.passthrough {
if !w.wroteHeader {
w.WriteHeader(http.StatusOK)
}
return w.ResponseWriter.Write(data)
}
if !w.wroteHeader {
w.WriteHeader(http.StatusOK)
}
@ -61,6 +80,12 @@ func (w *responseMetaRecorder) Write(data []byte) (int, error) {
}
func (w *responseMetaRecorder) WriteString(s string) (int, error) {
if w.passthrough {
if !w.wroteHeader {
w.WriteHeader(http.StatusOK)
}
return w.ResponseWriter.WriteString(s)
}
if !w.wroteHeader {
w.WriteHeader(http.StatusOK)
}
@ -68,6 +93,23 @@ func (w *responseMetaRecorder) WriteString(s string) (int, error) {
}
func (w *responseMetaRecorder) Flush() {
if w.passthrough {
if f, ok := w.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
return
}
if !w.wroteHeader {
w.WriteHeader(http.StatusOK)
}
w.commit(true)
w.passthrough = true
if f, ok := w.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
}
func (w *responseMetaRecorder) Status() int {
@ -87,10 +129,27 @@ func (w *responseMetaRecorder) Written() bool {
}
func (w *responseMetaRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if w.passthrough {
if h, ok := w.ResponseWriter.(http.Hijacker); ok {
return h.Hijack()
}
return nil, nil, io.ErrClosedPipe
}
w.wroteHeader = true
w.passthrough = true
if h, ok := w.ResponseWriter.(http.Hijacker); ok {
return h.Hijack()
}
return nil, nil, io.ErrClosedPipe
}
func (w *responseMetaRecorder) commit() {
func (w *responseMetaRecorder) commit(writeBody bool) {
if w.committed {
return
}
for k := range w.ResponseWriter.Header() {
w.ResponseWriter.Header().Del(k)
}
@ -102,7 +161,11 @@ func (w *responseMetaRecorder) commit() {
}
w.ResponseWriter.WriteHeader(w.Status())
_, _ = w.ResponseWriter.Write(w.body.Bytes())
if writeBody {
_, _ = w.ResponseWriter.Write(w.body.Bytes())
w.body.Reset()
}
w.committed = true
}
// responseMetaMiddleware injects request metadata into JSON envelope
@ -118,6 +181,10 @@ func responseMetaMiddleware() gin.HandlerFunc {
c.Next()
if recorder.passthrough {
return
}
body := recorder.body.Bytes()
if meta := GetRequestMeta(c); meta != nil && shouldAttachResponseMeta(recorder.Header().Get("Content-Type"), body) {
if refreshed := refreshResponseMetaBody(body, meta); refreshed != nil {
@ -128,7 +195,7 @@ func responseMetaMiddleware() gin.HandlerFunc {
recorder.body.Reset()
_, _ = recorder.body.Write(body)
recorder.Header().Set("Content-Length", strconv.Itoa(len(body)))
recorder.commit()
recorder.commit(true)
}
}

View file

@ -242,6 +242,67 @@ func TestWithSSE_Good_CombinesWithOtherMiddleware(t *testing.T) {
}
}
func TestWithSSE_Good_WithResponseMetaStillStreamsEvents(t *testing.T) {
gin.SetMode(gin.TestMode)
broker := api.NewSSEBroker()
e, err := api.New(
api.WithRequestID(),
api.WithResponseMeta(),
api.WithSSE(broker),
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
srv := httptest.NewServer(e.Handler())
defer srv.Close()
resp, err := http.Get(srv.URL + "/events")
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if ct := resp.Header.Get("Content-Type"); !strings.HasPrefix(ct, "text/event-stream") {
t.Fatalf("expected Content-Type starting with text/event-stream, got %q", ct)
}
if reqID := resp.Header.Get("X-Request-ID"); reqID == "" {
t.Fatal("expected X-Request-ID header from RequestID middleware")
}
waitForClients(t, broker, 1)
broker.Publish("test", "greeting", map[string]string{"msg": "hello"})
scanner := bufio.NewScanner(resp.Body)
var eventLine string
deadline := time.After(3 * time.Second)
done := make(chan struct{})
go func() {
defer close(done)
for scanner.Scan() {
line := scanner.Text()
if after, ok := strings.CutPrefix(line, "event: "); ok {
eventLine = after
return
}
}
}()
select {
case <-done:
case <-deadline:
t.Fatal("timed out waiting for SSE event with response meta enabled")
}
if eventLine != "greeting" {
t.Fatalf("expected event=%q, got %q", "greeting", eventLine)
}
}
func TestWithSSE_Good_MultipleClients(t *testing.T) {
gin.SetMode(gin.TestMode)

View file

@ -118,6 +118,53 @@ func TestWSEndpoint_Good_CustomPath(t *testing.T) {
}
}
func TestWSEndpoint_Good_WithResponseMeta(t *testing.T) {
gin.SetMode(gin.TestMode)
upgrader := websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
}
wsHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Logf("upgrade error: %v", err)
return
}
defer conn.Close()
_ = conn.WriteMessage(websocket.TextMessage, []byte("meta"))
})
e, err := api.New(
api.WithRequestID(),
api.WithResponseMeta(),
api.WithWSHandler(wsHandler),
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
srv := httptest.NewServer(e.Handler())
defer srv.Close()
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "/ws"
conn, resp, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
if resp != nil {
t.Fatalf("failed to dial WebSocket: %v (status=%d)", err, resp.StatusCode)
}
t.Fatalf("failed to dial WebSocket: %v", err)
}
defer conn.Close()
_, msg, err := conn.ReadMessage()
if err != nil {
t.Fatalf("failed to read message: %v", err)
}
if string(msg) != "meta" {
t.Fatalf("expected message=%q, got %q", "meta", string(msg))
}
}
func TestNoWSHandler_Good(t *testing.T) {
gin.SetMode(gin.TestMode)