go-api/brotli.go
Snider 64a8b16ca2 feat: add WithBrotli response compression middleware
Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-20 23:37:46 +00:00

120 lines
2.7 KiB
Go

// SPDX-License-Identifier: EUPL-1.2
package api
import (
"io"
"net/http"
"strconv"
"strings"
"sync"
"github.com/andybalholm/brotli"
"github.com/gin-gonic/gin"
)
// Compression level aliases for the levels defined by the brotli package.
const (
	// BrotliBestSpeed is the fastest Brotli level; it is the lower bound of
	// the range newBrotliHandler accepts.
	BrotliBestSpeed = brotli.BestSpeed

	// BrotliDefaultCompression is the brotli package's default level; it is
	// also what newBrotliHandler falls back to for out-of-range levels.
	BrotliDefaultCompression = brotli.DefaultCompression

	// BrotliBestCompression is the level that yields the smallest output; it
	// is the upper bound of the range newBrotliHandler accepts.
	BrotliBestCompression = brotli.BestCompression
)
// brotliHandler holds the state shared by every request that passes through
// the Brotli middleware: the chosen compression level and a pool of writers
// so that a compressor can be reused instead of being rebuilt per request.
type brotliHandler struct {
	// level is the compression level the pooled writers are created with.
	level int
	// pool hands out *brotli.Writer values; see newBrotliHandler for its
	// New function.
	pool sync.Pool
}
// newBrotliHandler builds a brotliHandler whose pooled writers compress at
// the requested level. A level outside the
// [BrotliBestSpeed, BrotliBestCompression] range is replaced with
// BrotliDefaultCompression.
func newBrotliHandler(level int) *brotliHandler {
	inRange := level >= BrotliBestSpeed && level <= BrotliBestCompression
	if !inRange {
		level = BrotliDefaultCompression
	}

	// Fresh writers target io.Discard; Handle re-points them at the real
	// response writer with Reset before use.
	newWriter := func() any {
		return brotli.NewWriterLevel(io.Discard, level)
	}

	return &brotliHandler{
		pool:  sync.Pool{New: newWriter},
		level: level,
	}
}
// Handle is the Gin middleware function that compresses responses with Brotli.
//
// Requests whose Accept-Encoding header does not list the "br" coding with a
// non-zero quality value are passed to the next handler untouched. Otherwise
// a pooled brotli.Writer is attached to the response, the Content-Encoding
// and Vary headers are set, and c.Writer is swapped for a brotliWriter for
// the remainder of the handler chain. When the chain returns, the compressor
// is closed and handed back to the pool.
func (h *brotliHandler) Handle(c *gin.Context) {
	// Join every Accept-Encoding line: a client may legally split the list
	// across several header fields.
	accept := strings.Join(c.Request.Header.Values("Accept-Encoding"), ",")
	if !acceptsBrotli(accept) {
		c.Next()
		return
	}

	w := h.pool.Get().(*brotli.Writer)
	w.Reset(c.Writer)

	c.Header("Content-Encoding", "br")
	c.Writer.Header().Add("Vary", "Accept-Encoding")

	bw := &brotliWriter{ResponseWriter: c.Writer, writer: w}
	c.Writer = bw

	defer func() {
		if bw.status >= http.StatusBadRequest {
			// brotliWriter.Write sends error bodies uncompressed, so take
			// back the headers set above and make sure Close below emits
			// nothing into the response.
			bw.Header().Del("Content-Encoding")
			dropVaryAcceptEncoding(bw.Header())
			w.Reset(io.Discard)
		} else if c.Writer.Size() < 0 {
			// A negative size is taken to mean no body was written; discard
			// the stream trailer rather than creating a body out of it.
			w.Reset(io.Discard)
		}

		// Best effort: the response is already under way, so there is no
		// caller left to report a close error to.
		_ = w.Close()

		// NOTE(review): if the headers were already sent this has no effect
		// on the wire — confirm whether it is still needed.
		if c.Writer.Size() > -1 {
			c.Header("Content-Length", strconv.Itoa(c.Writer.Size()))
		}
		h.pool.Put(w)
	}()

	c.Next()
}

// acceptsBrotli reports whether an Accept-Encoding header value lists the
// "br" content coding with a quality value greater than zero. Matching is
// per token and case-insensitive; a missing or unparsable q parameter counts
// as acceptance. The "*" wildcard is deliberately not treated as "br".
func acceptsBrotli(header string) bool {
	for _, item := range strings.Split(header, ",") {
		coding, params, _ := strings.Cut(item, ";")
		if !strings.EqualFold(strings.TrimSpace(coding), "br") {
			continue
		}
		for _, param := range strings.Split(params, ";") {
			name, value, found := strings.Cut(param, "=")
			if !found || !strings.EqualFold(strings.TrimSpace(name), "q") {
				continue
			}
			q, err := strconv.ParseFloat(strings.TrimSpace(value), 64)
			// "br;q=0" is an explicit refusal of the coding.
			return err != nil || q > 0
		}
		return true
	}
	return false
}

// dropVaryAcceptEncoding removes the "Accept-Encoding" entries from the Vary
// header in hdr and keeps every other Vary value, so values contributed by
// other middleware survive.
func dropVaryAcceptEncoding(hdr http.Header) {
	vary := hdr.Values("Vary")
	kept := make([]string, 0, len(vary))
	for _, v := range vary {
		if !strings.EqualFold(strings.TrimSpace(v), "Accept-Encoding") {
			kept = append(kept, v)
		}
	}
	if len(kept) == 0 {
		hdr.Del("Vary")
		return
	}
	hdr["Vary"] = kept
}
// brotliWriter is the gin.ResponseWriter installed by the Brotli middleware.
// It embeds the original response writer and overrides the write-related
// methods so that body bytes can be routed through a brotli.Writer.
type brotliWriter struct {
	gin.ResponseWriter

	// writer is the compressor that body bytes are written to.
	writer *brotli.Writer

	// status is the response status code as last observed, either from an
	// explicit WriteHeader call or from the embedded writer during Write.
	status int
	// statusWritten records that WriteHeader was called on this wrapper.
	statusWritten bool
}
// Write sends data to the client, compressing it unless the response status
// is an error.
//
// Content-Length is always removed because the final body length is not the
// length of data once compression is involved. If WriteHeader has not been
// called on this wrapper, the status is read from the embedded writer. For
// statuses of 400 and above the data bypasses the compressor: the
// Content-Encoding header and the "Accept-Encoding" entry of Vary are
// removed and the bytes go straight to the embedded writer. Other Vary
// values are preserved.
//
// It returns the byte count and error reported by whichever writer handled
// the data.
func (b *brotliWriter) Write(data []byte) (int, error) {
	hdr := b.Header()
	hdr.Del("Content-Length")

	if !b.statusWritten {
		b.status = b.ResponseWriter.Status()
	}

	if b.status < http.StatusBadRequest {
		return b.writer.Write(data)
	}

	// Error responses are passed through uncompressed.
	hdr.Del("Content-Encoding")

	// Remove only the Accept-Encoding entry from Vary; deleting the whole
	// header would also discard values set by other middleware.
	vary := hdr.Values("Vary")
	kept := make([]string, 0, len(vary))
	for _, v := range vary {
		if !strings.EqualFold(strings.TrimSpace(v), "Accept-Encoding") {
			kept = append(kept, v)
		}
	}
	if len(kept) == 0 {
		hdr.Del("Vary")
	} else {
		hdr["Vary"] = kept
	}

	return b.ResponseWriter.Write(data)
}
// WriteString writes s to the response by converting it to bytes and
// delegating to Write, so strings follow the same compression rules as any
// other body data. It returns whatever Write returns.
func (b *brotliWriter) WriteString(s string) (int, error) {
	payload := []byte(s)
	return b.Write(payload)
}
// WriteHeader records the status code on the wrapper, drops any
// Content-Length header, and forwards the code to the embedded writer.
func (b *brotliWriter) WriteHeader(code int) {
	// Any Content-Length set by the handler describes the uncompressed body.
	b.Header().Del("Content-Length")

	// Remember the explicit status so Write does not need to query the
	// embedded writer for it.
	b.statusWritten = true
	b.status = code

	b.ResponseWriter.WriteHeader(code)
}
// Flush pushes buffered output to the client.
//
// For non-error responses the compressor is flushed first so that its
// pending bytes reach the embedded writer, which is then flushed. For
// statuses of 400 and above Write bypasses the compressor, so the compressor
// is left alone: flushing it could place compressor framing bytes into a
// body that is being sent uncompressed.
func (b *brotliWriter) Flush() {
	// Resolve the status the same way Write does, without recording it.
	status := b.status
	if !b.statusWritten {
		status = b.ResponseWriter.Status()
	}

	if status < http.StatusBadRequest {
		// Best effort: Flush has no error return to surface a failure.
		_ = b.writer.Flush()
	}
	b.ResponseWriter.Flush()
}