fix(api): preserve streaming response passthrough
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
d7ef3610f7
commit
29f4c23977
3 changed files with 178 additions and 3 deletions
|
|
@ -25,6 +25,8 @@ type responseMetaRecorder struct {
|
|||
body bytes.Buffer
|
||||
status int
|
||||
wroteHeader bool
|
||||
committed bool
|
||||
passthrough bool
|
||||
}
|
||||
|
||||
func newResponseMetaRecorder(w gin.ResponseWriter) *responseMetaRecorder {
|
||||
|
|
@ -45,15 +47,32 @@ func (w *responseMetaRecorder) Header() http.Header {
|
|||
}
|
||||
|
||||
func (w *responseMetaRecorder) WriteHeader(code int) {
|
||||
if w.passthrough {
|
||||
w.status = code
|
||||
w.wroteHeader = true
|
||||
w.ResponseWriter.WriteHeader(code)
|
||||
return
|
||||
}
|
||||
w.status = code
|
||||
w.wroteHeader = true
|
||||
}
|
||||
|
||||
func (w *responseMetaRecorder) WriteHeaderNow() {
|
||||
if w.passthrough {
|
||||
w.wroteHeader = true
|
||||
w.ResponseWriter.WriteHeaderNow()
|
||||
return
|
||||
}
|
||||
w.wroteHeader = true
|
||||
}
|
||||
|
||||
func (w *responseMetaRecorder) Write(data []byte) (int, error) {
|
||||
if w.passthrough {
|
||||
if !w.wroteHeader {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
return w.ResponseWriter.Write(data)
|
||||
}
|
||||
if !w.wroteHeader {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
|
@ -61,6 +80,12 @@ func (w *responseMetaRecorder) Write(data []byte) (int, error) {
|
|||
}
|
||||
|
||||
func (w *responseMetaRecorder) WriteString(s string) (int, error) {
|
||||
if w.passthrough {
|
||||
if !w.wroteHeader {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
return w.ResponseWriter.WriteString(s)
|
||||
}
|
||||
if !w.wroteHeader {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
|
@ -68,6 +93,23 @@ func (w *responseMetaRecorder) WriteString(s string) (int, error) {
|
|||
}
|
||||
|
||||
func (w *responseMetaRecorder) Flush() {
|
||||
if w.passthrough {
|
||||
if f, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if !w.wroteHeader {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
w.commit(true)
|
||||
w.passthrough = true
|
||||
|
||||
if f, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (w *responseMetaRecorder) Status() int {
|
||||
|
|
@ -87,10 +129,27 @@ func (w *responseMetaRecorder) Written() bool {
|
|||
}
|
||||
|
||||
func (w *responseMetaRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if w.passthrough {
|
||||
if h, ok := w.ResponseWriter.(http.Hijacker); ok {
|
||||
return h.Hijack()
|
||||
}
|
||||
return nil, nil, io.ErrClosedPipe
|
||||
}
|
||||
|
||||
w.wroteHeader = true
|
||||
w.passthrough = true
|
||||
|
||||
if h, ok := w.ResponseWriter.(http.Hijacker); ok {
|
||||
return h.Hijack()
|
||||
}
|
||||
return nil, nil, io.ErrClosedPipe
|
||||
}
|
||||
|
||||
func (w *responseMetaRecorder) commit() {
|
||||
func (w *responseMetaRecorder) commit(writeBody bool) {
|
||||
if w.committed {
|
||||
return
|
||||
}
|
||||
|
||||
for k := range w.ResponseWriter.Header() {
|
||||
w.ResponseWriter.Header().Del(k)
|
||||
}
|
||||
|
|
@ -102,7 +161,11 @@ func (w *responseMetaRecorder) commit() {
|
|||
}
|
||||
|
||||
w.ResponseWriter.WriteHeader(w.Status())
|
||||
_, _ = w.ResponseWriter.Write(w.body.Bytes())
|
||||
if writeBody {
|
||||
_, _ = w.ResponseWriter.Write(w.body.Bytes())
|
||||
w.body.Reset()
|
||||
}
|
||||
w.committed = true
|
||||
}
|
||||
|
||||
// responseMetaMiddleware injects request metadata into JSON envelope
|
||||
|
|
@ -118,6 +181,10 @@ func responseMetaMiddleware() gin.HandlerFunc {
|
|||
|
||||
c.Next()
|
||||
|
||||
if recorder.passthrough {
|
||||
return
|
||||
}
|
||||
|
||||
body := recorder.body.Bytes()
|
||||
if meta := GetRequestMeta(c); meta != nil && shouldAttachResponseMeta(recorder.Header().Get("Content-Type"), body) {
|
||||
if refreshed := refreshResponseMetaBody(body, meta); refreshed != nil {
|
||||
|
|
@ -128,7 +195,7 @@ func responseMetaMiddleware() gin.HandlerFunc {
|
|||
recorder.body.Reset()
|
||||
_, _ = recorder.body.Write(body)
|
||||
recorder.Header().Set("Content-Length", strconv.Itoa(len(body)))
|
||||
recorder.commit()
|
||||
recorder.commit(true)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
61
sse_test.go
61
sse_test.go
|
|
@ -242,6 +242,67 @@ func TestWithSSE_Good_CombinesWithOtherMiddleware(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestWithSSE_Good_WithResponseMetaStillStreamsEvents(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
broker := api.NewSSEBroker()
|
||||
e, err := api.New(
|
||||
api.WithRequestID(),
|
||||
api.WithResponseMeta(),
|
||||
api.WithSSE(broker),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(e.Handler())
|
||||
defer srv.Close()
|
||||
|
||||
resp, err := http.Get(srv.URL + "/events")
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if ct := resp.Header.Get("Content-Type"); !strings.HasPrefix(ct, "text/event-stream") {
|
||||
t.Fatalf("expected Content-Type starting with text/event-stream, got %q", ct)
|
||||
}
|
||||
if reqID := resp.Header.Get("X-Request-ID"); reqID == "" {
|
||||
t.Fatal("expected X-Request-ID header from RequestID middleware")
|
||||
}
|
||||
|
||||
waitForClients(t, broker, 1)
|
||||
|
||||
broker.Publish("test", "greeting", map[string]string{"msg": "hello"})
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
var eventLine string
|
||||
|
||||
deadline := time.After(3 * time.Second)
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(done)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if after, ok := strings.CutPrefix(line, "event: "); ok {
|
||||
eventLine = after
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-deadline:
|
||||
t.Fatal("timed out waiting for SSE event with response meta enabled")
|
||||
}
|
||||
|
||||
if eventLine != "greeting" {
|
||||
t.Fatalf("expected event=%q, got %q", "greeting", eventLine)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithSSE_Good_MultipleClients(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
|
|
|||
|
|
@ -118,6 +118,53 @@ func TestWSEndpoint_Good_CustomPath(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestWSEndpoint_Good_WithResponseMeta(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
upgrader := websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
wsHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Logf("upgrade error: %v", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
_ = conn.WriteMessage(websocket.TextMessage, []byte("meta"))
|
||||
})
|
||||
|
||||
e, err := api.New(
|
||||
api.WithRequestID(),
|
||||
api.WithResponseMeta(),
|
||||
api.WithWSHandler(wsHandler),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(e.Handler())
|
||||
defer srv.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "/ws"
|
||||
conn, resp, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
if resp != nil {
|
||||
t.Fatalf("failed to dial WebSocket: %v (status=%d)", err, resp.StatusCode)
|
||||
}
|
||||
t.Fatalf("failed to dial WebSocket: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read message: %v", err)
|
||||
}
|
||||
if string(msg) != "meta" {
|
||||
t.Fatalf("expected message=%q, got %q", "meta", string(msg))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNoWSHandler_Good(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue