traefik/server/server_entrypoint_tcp_test.go
Julien Salleyron 4a68d29ce2 Add a new protocol
Co-authored-by: Gérald Croës <gerald@containo.us>
2019-03-14 09:30:04 +01:00

142 lines
3.6 KiB
Go

package server
import (
"bufio"
"context"
"net"
"net/http"
"testing"
"time"
"github.com/containous/traefik/config/static"
"github.com/containous/traefik/tcp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestShutdownHTTP sends an HTTP request over a connection that was opened
// before Shutdown was triggered and expects the handler's 200 response to
// still be delivered while the entry point is shutting down.
func TestShutdownHTTP(t *testing.T) {
	entryPoint, err := NewTCPEntryPoint(context.Background(), &static.EntryPoint{
		Address: ":0",
		Transport: &static.EntryPointsTransport{
			LifeCycle: &static.LifeCycle{
				RequestAcceptGraceTimeout: 0,
				GraceTimeOut:              5000000000, // NOTE(review): presumably 5s in nanoseconds — confirm the field's unit.
			},
		},
		ForwardedHeaders: &static.ForwardedHeaders{},
	})
	require.NoError(t, err)

	go entryPoint.startTCP(context.Background())

	router := &tcp.Router{}
	router.HTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		// Keep the request in flight so it is still running once Shutdown starts.
		time.Sleep(time.Second * 1)
		rw.WriteHeader(http.StatusOK)
	}))
	entryPoint.switchRouter(router)

	conn, err := net.Dial("tcp", entryPoint.listener.Addr().String())
	require.NoError(t, err)
	defer conn.Close()

	// Shutdown is triggered before the request is written on the open connection.
	go entryPoint.Shutdown(context.Background())

	// The URL only fills in the request line and Host header; the request is
	// written directly to conn rather than dialed.
	request, err := http.NewRequest("GET", "http://127.0.0.1:8082", nil)
	require.NoError(t, err)

	err = request.Write(conn)
	require.NoError(t, err)

	resp, err := http.ReadResponse(bufio.NewReader(conn), request)
	require.NoError(t, err)
	defer resp.Body.Close()

	assert.Equal(t, http.StatusOK, resp.StatusCode)
}
// TestShutdownHTTPHijacked is like TestShutdownHTTP, but the handler hijacks
// the underlying connection and writes the 200 response itself. It expects
// that response to reach the client after Shutdown has been triggered.
func TestShutdownHTTPHijacked(t *testing.T) {
	entryPoint, err := NewTCPEntryPoint(context.Background(), &static.EntryPoint{
		Address: ":0",
		Transport: &static.EntryPointsTransport{
			LifeCycle: &static.LifeCycle{
				RequestAcceptGraceTimeout: 0,
				GraceTimeOut:              5000000000, // NOTE(review): presumably 5s in nanoseconds — confirm the field's unit.
			},
		},
		ForwardedHeaders: &static.ForwardedHeaders{},
	})
	require.NoError(t, err)

	go entryPoint.startTCP(context.Background())

	router := &tcp.Router{}
	router.HTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		hijackedConn, _, err := rw.(http.Hijacker).Hijack()
		// The handler runs outside the test goroutine, where t.FailNow (and so
		// require.*) must not be called; report with assert and bail out instead.
		if !assert.NoError(t, err) {
			return
		}
		// After Hijack, net/http no longer manages the connection; the handler
		// is responsible for closing it.
		defer hijackedConn.Close()

		// Keep the request in flight so it is still running once Shutdown starts.
		time.Sleep(time.Second * 1)
		resp := http.Response{StatusCode: http.StatusOK}
		assert.NoError(t, resp.Write(hijackedConn))
	}))
	entryPoint.switchRouter(router)

	conn, err := net.Dial("tcp", entryPoint.listener.Addr().String())
	require.NoError(t, err)
	defer conn.Close()

	// Shutdown is triggered before the request is written on the open connection.
	go entryPoint.Shutdown(context.Background())

	// The URL only fills in the request line and Host header; the request is
	// written directly to conn rather than dialed.
	request, err := http.NewRequest("GET", "http://127.0.0.1:8082", nil)
	require.NoError(t, err)

	err = request.Write(conn)
	require.NoError(t, err)

	resp, err := http.ReadResponse(bufio.NewReader(conn), request)
	require.NoError(t, err)
	defer resp.Body.Close()

	assert.Equal(t, http.StatusOK, resp.StatusCode)
}
// TestShutdownTCPConn routes the connection to a raw TCP catch-all handler
// that reads an HTTP request and writes a 200 response by hand. It expects
// that response to reach the client after Shutdown has been triggered.
func TestShutdownTCPConn(t *testing.T) {
	entryPoint, err := NewTCPEntryPoint(context.Background(), &static.EntryPoint{
		Address: ":0",
		Transport: &static.EntryPointsTransport{
			LifeCycle: &static.LifeCycle{
				RequestAcceptGraceTimeout: 0,
				GraceTimeOut:              5000000000, // NOTE(review): presumably 5s in nanoseconds — confirm the field's unit.
			},
		},
		ForwardedHeaders: &static.ForwardedHeaders{},
	})
	require.NoError(t, err)

	go entryPoint.startTCP(context.Background())

	router := &tcp.Router{}
	router.AddCatchAllNoTLS(tcp.HandlerFunc(func(serverConn net.Conn) {
		// NOTE(review): assumes the TCP handler owns serverConn and should close
		// it when done — confirm against the tcp package's handler contract.
		defer serverConn.Close()

		_, err := http.ReadRequest(bufio.NewReader(serverConn))
		// The handler runs outside the test goroutine, where t.FailNow (and so
		// require.*) must not be called; report with assert and bail out instead.
		if !assert.NoError(t, err) {
			return
		}

		// Keep the connection busy so it is still in use once Shutdown starts.
		time.Sleep(time.Second * 1)
		resp := http.Response{StatusCode: http.StatusOK}
		assert.NoError(t, resp.Write(serverConn))
	}))
	entryPoint.switchRouter(router)

	conn, err := net.Dial("tcp", entryPoint.listener.Addr().String())
	require.NoError(t, err)
	defer conn.Close()

	// Shutdown is triggered before the request is written on the open connection.
	go entryPoint.Shutdown(context.Background())

	// The URL only fills in the request line and Host header; the request is
	// written directly to conn rather than dialed.
	request, err := http.NewRequest("GET", "http://127.0.0.1:8082", nil)
	require.NoError(t, err)

	err = request.Write(conn)
	require.NoError(t, err)

	resp, err := http.ReadResponse(bufio.NewReader(conn), request)
	require.NoError(t, err)
	defer resp.Body.Close()

	assert.Equal(t, http.StatusOK, resp.StatusCode)
}