traefik/pkg/middlewares/response_modifier.go
2024-01-15 16:14:05 +01:00

106 lines
2.5 KiB
Go

package middlewares
import (
"bufio"
"fmt"
"net"
"net/http"
"github.com/rs/zerolog/log"
)
// ResponseModifier wraps an http.ResponseWriter and lets a modifier function
// alter the response headers right before they are sent to the client.
type ResponseModifier struct {
	// req is the request being answered; it is handed to the modifier.
	req *http.Request
	// rw is the wrapped writer that actually emits the response.
	rw http.ResponseWriter

	// headersSent records that a final (non-informational) header was written.
	headersSent bool
	// code is the status to send; it must start at 200 so an implicit
	// WriteHeader from Write behaves like net/http.
	code int

	// modifier edits the outgoing headers; nil means "no modification".
	modifier func(*http.Response) error
	// modified records that modifier already ran for the current request.
	modified bool
	// modifierErr holds the error returned by modifier, if any.
	modifierErr error
}
// NewResponseModifier wraps w so that modifier is applied to the response
// headers of r just before they are written. A nil modifier is allowed, in
// which case headers are forwarded unchanged.
func NewResponseModifier(w http.ResponseWriter, r *http.Request, modifier func(*http.Response) error) http.ResponseWriter {
	wrapped := &ResponseModifier{
		// Nothing has called WriteHeader yet, so the implicit status is 200.
		code:     http.StatusOK,
		modifier: modifier,
		rw:       w,
		req:      r,
	}
	return wrapped
}
// WriteHeader sends the response header with the given status code, applying
// the modifier (when set) to the headers beforehand.
//
// Informational status codes (1xx) other than 101 Switching Protocols are a
// direct call to the wrapped ResponseWriter. The modifier is not applied and
// the headers are not marked as sent, so further calls remain possible. 101 is
// handled as a final status, as net/http does, because no other header may
// follow it. This also means the modifier is applied to protocol-upgrade
// responses.
//
// After a final status has been written, further calls are no-ops. If the
// modifier returns an error, the error is recorded (Write returns it) and
// logged, and a 500 Internal Server Error is sent instead of code. Changes the
// modifier makes to the response's StatusCode are ignored; code is what gets
// written on success.
func (r *ResponseModifier) WriteHeader(code int) {
	if r.headersSent {
		return
	}

	// Handling informational headers.
	// 101 is excluded: like net/http, treat it as the final response header.
	if code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols {
		r.rw.WriteHeader(code)
		return
	}

	// Whatever happens below, this call commits the final header.
	defer func() {
		r.code = code
		r.headersSent = true
	}()

	if r.modifier == nil || r.modified {
		r.rw.WriteHeader(code)
		return
	}

	resp := http.Response{
		// Expose the status to the modifier; otherwise StatusCode would be 0.
		StatusCode: code,
		Header:     r.rw.Header(),
		Request:    r.req,
	}
	if err := r.modifier(&resp); err != nil {
		r.modifierErr = err
		// We are propagating when we are called in Write, but we're logging
		// anyway, because we could be called from another place which does not
		// take care of checking r.modifierErr.
		log.Error().Err(err).Msg("Error when applying response modifier")
		r.rw.WriteHeader(http.StatusInternalServerError)
		return
	}

	r.modified = true
	r.rw.WriteHeader(code)
}
// Header returns the header map of the wrapped ResponseWriter. It is the same
// map that WriteHeader hands to the modifier.
func (r *ResponseModifier) Header() http.Header {
	underlying := r.rw
	return underlying.Header()
}
// Write sends b to the client. Before writing, it commits the headers via
// WriteHeader with the pending status code; that call is a no-op if the headers
// were already sent. If the modifier failed, nothing is written, and the
// modifier's error is returned with a count of 0.
func (r *ResponseModifier) Write(b []byte) (int, error) {
	r.WriteHeader(r.code)

	if r.modifierErr == nil {
		return r.rw.Write(b)
	}
	return 0, r.modifierErr
}
// Hijack takes over the underlying connection by delegating to the wrapped
// ResponseWriter. It returns an error when that writer does not implement
// http.Hijacker.
func (r *ResponseModifier) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	hijacker, ok := r.rw.(http.Hijacker)
	if !ok {
		return nil, nil, fmt.Errorf("not a hijacker: %T", r.rw)
	}
	return hijacker.Hijack()
}
// Flush sends any buffered data to the client, if the wrapped ResponseWriter
// implements http.Flusher; otherwise it does nothing.
//
// Flushing commits the response headers: net/http's ResponseWriter writes an
// implicit 200 header on Flush if none was written yet. WriteHeader is
// therefore called with the pending status code first, so the modifier is
// applied before the headers leave. If this step is skipped, the modifier
// would run later on headers that were already sent.
func (r *ResponseModifier) Flush() {
	flusher, ok := r.rw.(http.Flusher)
	if !ok {
		return
	}

	// No-op when the headers have already been sent.
	r.WriteHeader(r.code)
	flusher.Flush()
}