diff --git a/docs/content/reference/dynamic-configuration/file.toml b/docs/content/reference/dynamic-configuration/file.toml index 680e22a06..fd35f2b1e 100644 --- a/docs/content/reference/dynamic-configuration/file.toml +++ b/docs/content/reference/dynamic-configuration/file.toml @@ -65,6 +65,7 @@ [http.services.Service02] [http.services.Service02.mirroring] service = "foobar" + maxBodySize = 42 [[http.services.Service02.mirroring.mirrors]] name = "foobar" diff --git a/docs/content/reference/dynamic-configuration/file.yaml b/docs/content/reference/dynamic-configuration/file.yaml index deafc992a..432201fd7 100644 --- a/docs/content/reference/dynamic-configuration/file.yaml +++ b/docs/content/reference/dynamic-configuration/file.yaml @@ -72,6 +72,7 @@ http: Service02: mirroring: service: foobar + maxBodySize: 42 mirrors: - name: foobar percent: 42 diff --git a/docs/content/reference/dynamic-configuration/kubernetes-crd-resource.yml b/docs/content/reference/dynamic-configuration/kubernetes-crd-resource.yml index fd6e4ce79..4e6968fa5 100644 --- a/docs/content/reference/dynamic-configuration/kubernetes-crd-resource.yml +++ b/docs/content/reference/dynamic-configuration/kubernetes-crd-resource.yml @@ -65,6 +65,8 @@ spec: kind: TraefikService mirrors: - name: s2 + # Optional + maxBodySize: 2000000000 # Optional, as it is the default value kind: Service percent: 20 diff --git a/docs/content/reference/dynamic-configuration/kv-ref.md b/docs/content/reference/dynamic-configuration/kv-ref.md index 2ce8804e9..cbfa364c8 100644 --- a/docs/content/reference/dynamic-configuration/kv-ref.md +++ b/docs/content/reference/dynamic-configuration/kv-ref.md @@ -174,6 +174,7 @@ | `traefik/http/services/Service01/loadBalancer/sticky/cookie/httpOnly` | `true` | | `traefik/http/services/Service01/loadBalancer/sticky/cookie/name` | `foobar` | | `traefik/http/services/Service01/loadBalancer/sticky/cookie/secure` | `true` | +| `traefik/http/services/Service02/mirroring/maxBodySize` | `42` | | `traefik/http/services/Service02/mirroring/mirrors/0/name` | `foobar` | | `traefik/http/services/Service02/mirroring/mirrors/0/percent` | `42` | | `traefik/http/services/Service02/mirroring/mirrors/1/name` | `foobar` | diff --git a/docs/content/routing/services/index.md b/docs/content/routing/services/index.md index 067ff63a5..1f8abc936 100644 --- a/docs/content/routing/services/index.md +++ b/docs/content/routing/services/index.md @@ -462,6 +462,8 @@ http: ### Mirroring (service) The mirroring is able to mirror requests sent to a service to other services. +Please note that by default the whole request is buffered in memory while it is being mirrored. +See the maxBodySize option in the example below for how to modify this behaviour. !!! info "Supported Providers" @@ -473,6 +475,10 @@ The mirroring is able to mirror requests sent to a service to other services. [http.services.mirrored-api] [http.services.mirrored-api.mirroring] service = "appv1" + # maxBodySize is the maximum size in bytes allowed for the body of the request. + # If the body is larger, the request is not mirrored. + # Default value is -1, which means unlimited size. + maxBodySize = 1024 [[http.services.mirrored-api.mirroring.mirrors]] name = "appv2" percent = 10 @@ -495,6 +501,10 @@ http: mirrored-api: mirroring: service: appv1 + # maxBodySize is the maximum size allowed for the body of the request. + # If the body is larger, the request is not mirrored. + # Default value is -1, which means unlimited size. + maxBodySize = 1024 mirrors: - name: appv2 percent: 10 diff --git a/integration/fixtures/mirror.toml b/integration/fixtures/mirror.toml index 9448eec7a..55d0cf040 100644 --- a/integration/fixtures/mirror.toml +++ b/integration/fixtures/mirror.toml @@ -23,6 +23,11 @@ service = "mirror" rule = "Path(`/whoami`)" + [http.routers.router2] + service = "mirrorWithMaxBody" + rule = "Path(`/whoamiWithMaxBody`)" + + [http.services] [http.services.mirror.mirroring] service = "service1" @@ -33,6 +38,17 @@ name = "mirror2" percent = 50 + [http.services.mirrorWithMaxBody.mirroring] + service = "service1" + maxBodySize = 8 + [[http.services.mirrorWithMaxBody.mirroring.mirrors]] + name = "mirror1" + percent = 10 + [[http.services.mirrorWithMaxBody.mirroring.mirrors]] + name = "mirror2" + percent = 50 + + [http.services.service1.loadBalancer] [[http.services.service1.loadBalancer.servers]] url = "{{ .MainServer }}" diff --git a/integration/simple_test.go b/integration/simple_test.go index b72dd1088..4af8fdf39 100644 --- a/integration/simple_test.go +++ b/integration/simple_test.go @@ -2,6 +2,7 @@ package integration import ( "bytes" + "crypto/rand" "encoding/json" "fmt" "io/ioutil" @@ -777,6 +778,129 @@ func (s *SimpleSuite) TestMirror(c *check.C) { c.Assert(val2, checker.Equals, int32(5)) } +func (s *SimpleSuite) TestMirrorWithBody(c *check.C) { + var count, countMirror1, countMirror2 int32 + + body20 := make([]byte, 20) + _, err := rand.Read(body20) + c.Assert(err, checker.IsNil) + + body5 := make([]byte, 5) + _, err = rand.Read(body5) + c.Assert(err, checker.IsNil) + + verifyBody := func(req *http.Request) { + b, _ := ioutil.ReadAll(req.Body) + switch req.Header.Get("Size") { + case "20": + if !bytes.Equal(b, body20) { + c.Fatalf("Not Equals \n%v \n%v", body20, b) + } + case "5": + if !bytes.Equal(b, body5) { + c.Fatalf("Not Equals \n%v \n%v", body5, b) + } + default: + c.Fatal("Size header not present") + } + } + + main := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + verifyBody(req) + atomic.AddInt32(&count, 1) + })) + + mirror1 := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + verifyBody(req) + atomic.AddInt32(&countMirror1, 1) + })) + + mirror2 := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + verifyBody(req) + atomic.AddInt32(&countMirror2, 1) + })) + + mainServer := main.URL + mirror1Server := mirror1.URL + mirror2Server := mirror2.URL + + file := s.adaptFile(c, "fixtures/mirror.toml", struct { + MainServer string + Mirror1Server string + Mirror2Server string + }{MainServer: mainServer, Mirror1Server: mirror1Server, Mirror2Server: mirror2Server}) + defer os.Remove(file) + + cmd, output := s.traefikCmd(withConfigFile(file)) + defer output(c) + + err = cmd.Start() + c.Assert(err, checker.IsNil) + defer cmd.Process.Kill() + + err = try.GetRequest("http://127.0.0.1:8080/api/http/services", 1000*time.Millisecond, try.BodyContains("mirror1", "mirror2", "service1")) + c.Assert(err, checker.IsNil) + + req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8000/whoami", bytes.NewBuffer(body20)) + c.Assert(err, checker.IsNil) + req.Header.Set("Size", "20") + for i := 0; i < 10; i++ { + response, err := http.DefaultClient.Do(req) + c.Assert(err, checker.IsNil) + c.Assert(response.StatusCode, checker.Equals, http.StatusOK) + } + + countTotal := atomic.LoadInt32(&count) + val1 := atomic.LoadInt32(&countMirror1) + val2 := atomic.LoadInt32(&countMirror2) + + c.Assert(countTotal, checker.Equals, int32(10)) + c.Assert(val1, checker.Equals, int32(1)) + c.Assert(val2, checker.Equals, int32(5)) + + atomic.StoreInt32(&count, 0) + atomic.StoreInt32(&countMirror1, 0) + atomic.StoreInt32(&countMirror2, 0) + + req, err = http.NewRequest(http.MethodGet, "http://127.0.0.1:8000/whoamiWithMaxBody", bytes.NewBuffer(body5)) + req.Header.Set("Size", "5") + c.Assert(err, checker.IsNil) + for i := 0; i < 10; i++ { + response, err := http.DefaultClient.Do(req) + c.Assert(err, checker.IsNil) + c.Assert(response.StatusCode, checker.Equals, http.StatusOK) + } + + countTotal = atomic.LoadInt32(&count) + val1 = atomic.LoadInt32(&countMirror1) + val2 = atomic.LoadInt32(&countMirror2) + + c.Assert(countTotal, checker.Equals, int32(10)) + c.Assert(val1, checker.Equals, int32(1)) + c.Assert(val2, checker.Equals, int32(5)) + + atomic.StoreInt32(&count, 0) + atomic.StoreInt32(&countMirror1, 0) + atomic.StoreInt32(&countMirror2, 0) + + req, err = http.NewRequest(http.MethodGet, "http://127.0.0.1:8000/whoamiWithMaxBody", bytes.NewBuffer(body20)) + req.Header.Set("Size", "20") + c.Assert(err, checker.IsNil) + for i := 0; i < 10; i++ { + response, err := http.DefaultClient.Do(req) + c.Assert(err, checker.IsNil) + c.Assert(response.StatusCode, checker.Equals, http.StatusOK) + } + + countTotal = atomic.LoadInt32(&count) + val1 = atomic.LoadInt32(&countMirror1) + val2 = atomic.LoadInt32(&countMirror2) + + c.Assert(countTotal, checker.Equals, int32(10)) + c.Assert(val1, checker.Equals, int32(0)) + c.Assert(val2, checker.Equals, int32(0)) +} + func (s *SimpleSuite) TestMirrorCanceled(c *check.C) { var count, countMirror1, countMirror2 int32 diff --git a/integration/testdata/rawdata-consul.json b/integration/testdata/rawdata-consul.json index 4f8e547bb..8f70cc3b4 100644 --- a/integration/testdata/rawdata-consul.json +++ b/integration/testdata/rawdata-consul.json @@ -152,6 +152,7 @@ "mirror@consul": { "mirroring": { "service": "simplesvc", + "maxBodySize": -1, "mirrors": [ { "name": "srvcA", diff --git a/integration/testdata/rawdata-etcd.json b/integration/testdata/rawdata-etcd.json index 2cae91a68..d8e844220 100644 --- a/integration/testdata/rawdata-etcd.json +++ b/integration/testdata/rawdata-etcd.json @@ -152,6 +152,7 @@ "mirror@etcd": { "mirroring": { "service": "simplesvc", + "maxBodySize": -1, "mirrors": [ { "name": "srvcA", diff --git a/integration/testdata/rawdata-redis.json b/integration/testdata/rawdata-redis.json index 46bf3bcb6..15a7a0396 100644 --- a/integration/testdata/rawdata-redis.json +++ b/integration/testdata/rawdata-redis.json @@ -152,6 +152,7 @@ "mirror@redis": { "mirroring": { "service": "simplesvc", + "maxBodySize": -1, "mirrors": [ { "name": "srvcA", diff --git a/integration/testdata/rawdata-zk.json b/integration/testdata/rawdata-zk.json index dd84a60fd..494833760 100644 --- a/integration/testdata/rawdata-zk.json +++ b/integration/testdata/rawdata-zk.json @@ -152,6 +152,7 @@ "mirror@zookeeper": { "mirroring": { "service": "simplesvc", + "maxBodySize": -1, "mirrors": [ { "name": "srvcA", diff --git a/pkg/config/dynamic/http_config.go b/pkg/config/dynamic/http_config.go index ad4f3dcc1..dc999fa1d 100644 --- a/pkg/config/dynamic/http_config.go +++ b/pkg/config/dynamic/http_config.go @@ -58,8 +58,15 @@ type RouterTLSConfig struct { // Mirroring holds the Mirroring configuration. type Mirroring struct { - Service string `json:"service,omitempty" toml:"service,omitempty" yaml:"service,omitempty"` - Mirrors []MirrorService `json:"mirrors,omitempty" toml:"mirrors,omitempty" yaml:"mirrors,omitempty"` + Service string `json:"service,omitempty" toml:"service,omitempty" yaml:"service,omitempty"` + MaxBodySize *int64 `json:"maxBodySize,omitempty" toml:"maxBodySize,omitempty" yaml:"maxBodySize,omitempty"` + Mirrors []MirrorService `json:"mirrors,omitempty" toml:"mirrors,omitempty" yaml:"mirrors,omitempty"` +} + +// SetDefaults Default values for a WRRService. +func (m *Mirroring) SetDefaults() { + var defaultMaxBodySize int64 = -1 + m.MaxBodySize = &defaultMaxBodySize } // +k8s:deepcopy-gen=true diff --git a/pkg/config/dynamic/zz_generated.deepcopy.go b/pkg/config/dynamic/zz_generated.deepcopy.go index 65c120ac4..33cbef1b1 100644 --- a/pkg/config/dynamic/zz_generated.deepcopy.go +++ b/pkg/config/dynamic/zz_generated.deepcopy.go @@ -762,6 +762,11 @@ func (in *MirrorService) DeepCopy() *MirrorService { // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *Mirroring) DeepCopyInto(out *Mirroring) { *out = *in + if in.MaxBodySize != nil { + in, out := &in.MaxBodySize, &out.MaxBodySize + *out = new(int64) + **out = **in + } if in.Mirrors != nil { in, out := &in.Mirrors, &out.Mirrors *out = make([]MirrorService, len(*in)) diff --git a/pkg/provider/kubernetes/crd/kubernetes_http.go b/pkg/provider/kubernetes/crd/kubernetes_http.go index 3582951a1..d77b18fc6 100644 --- a/pkg/provider/kubernetes/crd/kubernetes_http.go +++ b/pkg/provider/kubernetes/crd/kubernetes_http.go @@ -240,8 +240,9 @@ func (c configBuilder) buildMirroring(ctx context.Context, tService *v1alpha1.Tr conf[id] = &dynamic.Service{ Mirroring: &dynamic.Mirroring{ - Service: fullNameMain, - Mirrors: mirrorServices, + Service: fullNameMain, + Mirrors: mirrorServices, + MaxBodySize: tService.Spec.Mirroring.MaxBodySize, }, } diff --git a/pkg/provider/kubernetes/crd/traefik/v1alpha1/service.go b/pkg/provider/kubernetes/crd/traefik/v1alpha1/service.go index da29b6b5c..f86d6054d 100644 --- a/pkg/provider/kubernetes/crd/traefik/v1alpha1/service.go +++ b/pkg/provider/kubernetes/crd/traefik/v1alpha1/service.go @@ -44,7 +44,8 @@ type ServiceSpec struct { // load-balancer, and a list of mirrors. type Mirroring struct { LoadBalancerSpec - Mirrors []MirrorService `json:"mirrors,omitempty"` + MaxBodySize *int64 + Mirrors []MirrorService `json:"mirrors,omitempty"` } // +k8s:deepcopy-gen=true diff --git a/pkg/provider/kubernetes/crd/traefik/v1alpha1/zz_generated.deepcopy.go b/pkg/provider/kubernetes/crd/traefik/v1alpha1/zz_generated.deepcopy.go index b1d58417a..69473f773 100644 --- a/pkg/provider/kubernetes/crd/traefik/v1alpha1/zz_generated.deepcopy.go +++ b/pkg/provider/kubernetes/crd/traefik/v1alpha1/zz_generated.deepcopy.go @@ -749,6 +749,11 @@ func (in *MirrorService) DeepCopy() *MirrorService { func (in *Mirroring) DeepCopyInto(out *Mirroring) { *out = *in in.LoadBalancerSpec.DeepCopyInto(&out.LoadBalancerSpec) + if in.MaxBodySize != nil { + in, out := &in.MaxBodySize, &out.MaxBodySize + *out = new(int64) + **out = **in + } if in.Mirrors != nil { in, out := &in.Mirrors, &out.Mirrors *out = make([]MirrorService, len(*in)) diff --git a/pkg/provider/kv/kv_test.go b/pkg/provider/kv/kv_test.go index 10de5d799..3ff8feb5d 100644 --- a/pkg/provider/kv/kv_test.go +++ b/pkg/provider/kv/kv_test.go @@ -56,6 +56,7 @@ func Test_buildConfiguration(t *testing.T) { "traefik/http/services/Service01/loadBalancer/servers/0/url": "foobar", "traefik/http/services/Service01/loadBalancer/servers/1/url": "foobar", "traefik/http/services/Service02/mirroring/service": "foobar", + "traefik/http/services/Service02/mirroring/maxBodySize": "42", "traefik/http/services/Service02/mirroring/mirrors/0/name": "foobar", "traefik/http/services/Service02/mirroring/mirrors/0/percent": "42", "traefik/http/services/Service02/mirroring/mirrors/1/name": "foobar", @@ -635,7 +636,8 @@ func Test_buildConfiguration(t *testing.T) { }, "Service02": { Mirroring: &dynamic.Mirroring{ - Service: "foobar", + Service: "foobar", + MaxBodySize: func(v int64) *int64 { return &v }(42), Mirrors: []dynamic.MirrorService{ { Name: "foobar", diff --git a/pkg/server/service/loadbalancer/mirror/mirror.go b/pkg/server/service/loadbalancer/mirror/mirror.go index 60545090a..959ef5775 100644 --- a/pkg/server/service/loadbalancer/mirror/mirror.go +++ b/pkg/server/service/loadbalancer/mirror/mirror.go @@ -2,12 +2,17 @@ package mirror import ( "bufio" + "bytes" "context" "errors" + "fmt" + "io" + "io/ioutil" "net" "net/http" "sync" + "github.com/containous/traefik/v2/pkg/log" "github.com/containous/traefik/v2/pkg/middlewares/accesslog" "github.com/containous/traefik/v2/pkg/safe" ) @@ -19,16 +24,19 @@ type Mirroring struct { rw http.ResponseWriter routinePool *safe.Pool + maxBodySize int64 + lock sync.RWMutex total uint64 } // New returns a new instance of *Mirroring. -func New(handler http.Handler, pool *safe.Pool) *Mirroring { +func New(handler http.Handler, pool *safe.Pool, maxBodySize int64) *Mirroring { return &Mirroring{ routinePool: pool, handler: handler, - rw: blackholeResponseWriter{}, + rw: blackHoleResponseWriter{}, + maxBodySize: maxBodySize, } } @@ -47,41 +55,73 @@ type mirrorHandler struct { count uint64 } +func (m *Mirroring) getActiveMirrors() []http.Handler { + total := m.inc() + + var mirrors []http.Handler + for _, handler := range m.mirrorHandlers { + handler.lock.Lock() + if handler.count*100 < total*uint64(handler.percent) { + handler.count++ + handler.lock.Unlock() + mirrors = append(mirrors, handler) + } else { + handler.lock.Unlock() + } + } + return mirrors +} + func (m *Mirroring) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - m.handler.ServeHTTP(rw, req) + mirrors := m.getActiveMirrors() + if len(mirrors) == 0 { + m.handler.ServeHTTP(rw, req) + return + } + + logger := log.FromContext(req.Context()) + rr, bytesRead, err := newReusableRequest(req, m.maxBodySize) + if err != nil && err != errBodyTooLarge { + http.Error(rw, http.StatusText(http.StatusInternalServerError)+ + fmt.Sprintf("error creating reusable request: %v", err), http.StatusInternalServerError) + return + } + + if err == errBodyTooLarge { + req.Body = ioutil.NopCloser(io.MultiReader(bytes.NewReader(bytesRead), req.Body)) + m.handler.ServeHTTP(rw, req) + logger.Debugf("no mirroring, request body larger than allowed size") + return + } + + m.handler.ServeHTTP(rw, rr.clone(req.Context())) select { case <-req.Context().Done(): // No mirroring if request has been canceled during main handler ServeHTTP + logger.Warn("no mirroring, request has been canceled during main handler ServeHTTP") return default: } m.routinePool.GoCtx(func(_ context.Context) { - total := m.inc() - for _, handler := range m.mirrorHandlers { - handler.lock.Lock() - if handler.count*100 < total*uint64(handler.percent) { - handler.count++ - handler.lock.Unlock() + for _, handler := range mirrors { + // prepare request, update body from buffer + r := rr.clone(req.Context()) - // In ServeHTTP, we rely on the presence of the accesslog datatable found in the - // request's context to know whether we should mutate said datatable (and - // contribute some fields to the log). In this instance, we do not want the mirrors - // mutating (i.e. changing the service name in) the logs related to the mirrored - // server. Especially since it would result in unguarded concurrent reads/writes on - // the datatable. Therefore, we reset any potential datatable key in the new - // context that we pass around. - ctx := context.WithValue(req.Context(), accesslog.DataTableKey, nil) + // In ServeHTTP, we rely on the presence of the accessLog datatable found in the request's context + // to know whether we should mutate said datatable (and contribute some fields to the log). + // In this instance, we do not want the mirrors mutating (i.e. changing the service name in) + // the logs related to the mirrored server. + // Especially since it would result in unguarded concurrent reads/writes on the datatable. + // Therefore, we reset any potential datatable key in the new context that we pass around. + ctx := context.WithValue(r.Context(), accesslog.DataTableKey, nil) - // When a request served by m.handler is successful, req.Context will be canceled, - // which would trigger a cancellation of the ongoing mirrored requests. - // Therefore, we give a new, non-cancellable context to each of the mirrored calls, - // so they can terminate by themselves. - handler.ServeHTTP(m.rw, req.WithContext(contextStopPropagation{ctx})) - } else { - handler.lock.Unlock() - } + // When a request served by m.handler is successful, req.Context will be canceled, + // which would trigger a cancellation of the ongoing mirrored requests. + // Therefore, we give a new, non-cancellable context to each of the mirrored calls, + // so they can terminate by themselves. + handler.ServeHTTP(m.rw, r.WithContext(contextStopPropagation{ctx})) } }) } @@ -95,23 +135,23 @@ func (m *Mirroring) AddMirror(handler http.Handler, percent int) error { return nil } -type blackholeResponseWriter struct{} +type blackHoleResponseWriter struct{} -func (b blackholeResponseWriter) Flush() {} +func (b blackHoleResponseWriter) Flush() {} -func (b blackholeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return nil, nil, errors.New("connection on blackholeResponseWriter cannot be hijacked") +func (b blackHoleResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, errors.New("connection on blackHoleResponseWriter cannot be hijacked") } -func (b blackholeResponseWriter) Header() http.Header { +func (b blackHoleResponseWriter) Header() http.Header { return http.Header{} } -func (b blackholeResponseWriter) Write(bytes []byte) (int, error) { +func (b blackHoleResponseWriter) Write(bytes []byte) (int, error) { return len(bytes), nil } -func (b blackholeResponseWriter) WriteHeader(statusCode int) {} +func (b blackHoleResponseWriter) WriteHeader(statusCode int) {} type contextStopPropagation struct { context.Context @@ -120,3 +160,65 @@ type contextStopPropagation struct { func (c contextStopPropagation) Done() <-chan struct{} { return make(chan struct{}) } + +// reusableRequest keeps in memory the body of the given request, +// so that the request can be fully cloned by each mirror. +type reusableRequest struct { + req *http.Request + body []byte +} + +var errBodyTooLarge = errors.New("request body too large") + +// if the returned error is errBodyTooLarge, newReusableRequest also returns the +// bytes that were already consumed from the request's body. +func newReusableRequest(req *http.Request, maxBodySize int64) (*reusableRequest, []byte, error) { + if req == nil { + return nil, nil, errors.New("nil input request") + } + if req.Body == nil { + return &reusableRequest{req: req}, nil, nil + } + + // unbounded body size + if maxBodySize < 0 { + body, err := ioutil.ReadAll(req.Body) + if err != nil { + return nil, nil, err + } + return &reusableRequest{ + req: req, + body: body, + }, nil, nil + } + + // we purposefully try to read _more_ than maxBodySize to detect whether + // the request body is larger than what we allow for the mirrors. + body := make([]byte, maxBodySize+1) + n, err := io.ReadFull(req.Body, body) + if err != nil && err != io.ErrUnexpectedEOF { + return nil, nil, err + } + + // we got an ErrUnexpectedEOF, which means there was less than maxBodySize data to read, + // which permits us sending also to all the mirrors later. + if err == io.ErrUnexpectedEOF { + return &reusableRequest{ + req: req, + body: body[:n], + }, nil, nil + } + + // err == nil , which means data size > maxBodySize + return nil, body[:n], errBodyTooLarge +} + +func (rr reusableRequest) clone(ctx context.Context) *http.Request { + req := rr.req.Clone(ctx) + + if rr.body != nil { + req.Body = ioutil.NopCloser(bytes.NewReader(rr.body)) + } + + return req +} diff --git a/pkg/server/service/loadbalancer/mirror/mirror_test.go b/pkg/server/service/loadbalancer/mirror/mirror_test.go index f223d72fd..f2e8da9d1 100644 --- a/pkg/server/service/loadbalancer/mirror/mirror_test.go +++ b/pkg/server/service/loadbalancer/mirror/mirror_test.go @@ -1,7 +1,9 @@ package mirror import ( + "bytes" "context" + "io/ioutil" "net/http" "net/http/httptest" "sync/atomic" @@ -11,13 +13,15 @@ import ( "github.com/stretchr/testify/assert" ) +const defaultMaxBodySize int64 = -1 + func TestMirroringOn100(t *testing.T) { var countMirror1, countMirror2 int32 handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(http.StatusOK) }) pool := safe.NewPool(context.Background()) - mirror := New(handler, pool) + mirror := New(handler, pool, defaultMaxBodySize) err := mirror.AddMirror(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { atomic.AddInt32(&countMirror1, 1) }), 10) @@ -46,7 +50,7 @@ func TestMirroringOn10(t *testing.T) { rw.WriteHeader(http.StatusOK) }) pool := safe.NewPool(context.Background()) - mirror := New(handler, pool) + mirror := New(handler, pool, defaultMaxBodySize) err := mirror.AddMirror(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { atomic.AddInt32(&countMirror1, 1) }), 10) @@ -70,7 +74,7 @@ func TestMirroringOn10(t *testing.T) { } func TestInvalidPercent(t *testing.T) { - mirror := New(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), safe.NewPool(context.Background())) + mirror := New(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), safe.NewPool(context.Background()), defaultMaxBodySize) err := mirror.AddMirror(nil, -1) assert.Error(t, err) @@ -89,7 +93,7 @@ func TestHijack(t *testing.T) { rw.WriteHeader(http.StatusOK) }) pool := safe.NewPool(context.Background()) - mirror := New(handler, pool) + mirror := New(handler, pool, defaultMaxBodySize) var mirrorRequest bool err := mirror.AddMirror(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { @@ -113,7 +117,7 @@ func TestFlush(t *testing.T) { rw.WriteHeader(http.StatusOK) }) pool := safe.NewPool(context.Background()) - mirror := New(handler, pool) + mirror := New(handler, pool, defaultMaxBodySize) var mirrorRequest bool err := mirror.AddMirror(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { @@ -131,3 +135,121 @@ func TestFlush(t *testing.T) { pool.Stop() assert.Equal(t, true, mirrorRequest) } + +func TestMirroringWithBody(t *testing.T) { + const numMirrors = 10 + + var ( + countMirror int32 + body = []byte(`body`) + ) + + pool := safe.NewPool(context.Background()) + + handler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + assert.NotNil(t, r.Body) + bb, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + assert.Equal(t, body, bb) + rw.WriteHeader(http.StatusOK) + }) + + mirror := New(handler, pool, defaultMaxBodySize) + + for i := 0; i < numMirrors; i++ { + err := mirror.AddMirror(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + assert.NotNil(t, r.Body) + bb, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + assert.Equal(t, body, bb) + atomic.AddInt32(&countMirror, 1) + }), 100) + assert.NoError(t, err) + } + + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBuffer(body)) + + mirror.ServeHTTP(httptest.NewRecorder(), req) + + pool.Stop() + + val := atomic.LoadInt32(&countMirror) + assert.Equal(t, numMirrors, int(val)) +} + +func TestCloneRequest(t *testing.T) { + t.Run("http request body is nil", func(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, "/", nil) + assert.NoError(t, err) + + ctx := req.Context() + rr, _, err := newReusableRequest(req, defaultMaxBodySize) + assert.NoError(t, err) + + // first call + cloned := rr.clone(ctx) + assert.Equal(t, cloned, req) + assert.Nil(t, cloned.Body) + + // second call + cloned = rr.clone(ctx) + assert.Equal(t, cloned, req) + assert.Nil(t, cloned.Body) + }) + + t.Run("http request body is not nil", func(t *testing.T) { + bb := []byte(`¯\_(ツ)_/¯`) + contentLength := len(bb) + + buf := bytes.NewBuffer(bb) + req, err := http.NewRequest(http.MethodPost, "/", buf) + assert.NoError(t, err) + + ctx := req.Context() + req.ContentLength = int64(contentLength) + + rr, _, err := newReusableRequest(req, defaultMaxBodySize) + assert.NoError(t, err) + + // first call + cloned := rr.clone(ctx) + body, err := ioutil.ReadAll(cloned.Body) + assert.NoError(t, err) + assert.Equal(t, bb, body) + + // second call + cloned = rr.clone(ctx) + body, err = ioutil.ReadAll(cloned.Body) + assert.NoError(t, err) + assert.Equal(t, bb, body) + }) + + t.Run("failed case", func(t *testing.T) { + bb := []byte(`1234567890`) + buf := bytes.NewBuffer(bb) + + req, err := http.NewRequest(http.MethodPost, "/", buf) + assert.NoError(t, err) + + _, expectedBytes, err := newReusableRequest(req, 2) + assert.Error(t, err) + assert.Equal(t, bb[:3], expectedBytes) + }) + + t.Run("valid case with maxBodySize", func(t *testing.T) { + bb := []byte(`1234567890`) + buf := bytes.NewBuffer(bb) + + req, err := http.NewRequest(http.MethodPost, "/", buf) + assert.NoError(t, err) + + _, expectedBytes, err := newReusableRequest(req, 20) + assert.NoError(t, err) + assert.Nil(t, expectedBytes) + }) + + t.Run("no request given", func(t *testing.T) { + _, _, err := newReusableRequest(nil, defaultMaxBodySize) + assert.Error(t, err) + }) +} diff --git a/pkg/server/service/service.go b/pkg/server/service/service.go index 3182234ae..ff5c510f4 100644 --- a/pkg/server/service/service.go +++ b/pkg/server/service/service.go @@ -33,6 +33,8 @@ const ( defaultHealthCheckTimeout = 5 * time.Second ) +const defaultMaxBodySize int64 = -1 + // NewManager creates a new Manager func NewManager(configs map[string]*runtime.ServiceInfo, defaultRoundTripper http.RoundTripper, metricsRegistry metrics.Registry, routinePool *safe.Pool) *Manager { return &Manager{ @@ -123,7 +125,11 @@ func (m *Manager) getMirrorServiceHandler(ctx context.Context, config *dynamic.M return nil, err } - handler := mirror.New(serviceHandler, m.routinePool) + maxBodySize := defaultMaxBodySize + if config.MaxBodySize != nil { + maxBodySize = *config.MaxBodySize + } + handler := mirror.New(serviceHandler, m.routinePool, maxBodySize) for _, mirrorConfig := range config.Mirrors { mirrorHandler, err := m.BuildHTTP(ctx, mirrorConfig.Name, responseModifier) if err != nil {