diff --git a/router.go b/router.go index b2e1a08..67b4671 100644 --- a/router.go +++ b/router.go @@ -176,11 +176,11 @@ func (r *APIGRouter) Respond() events.APIGatewayProxyResponse { functions := routeInterface.(routeFunctions) - for _, m := range functions.middleware { - functions.handler = m(functions.handler) + for i := len(functions.middleware) - 1; i >= 0; i-- { + functions.handler = functions.middleware[i](functions.handler) } - for _, m := range r.middleware { - functions.handler = m(functions.handler) + for i := len(r.middleware) - 1; i >= 0; i-- { + functions.handler = r.middleware[i](functions.handler) } functions.handler(ctx) @@ -190,6 +190,29 @@ func (r *APIGRouter) Respond() events.APIGatewayProxyResponse { return response } +// ConvertHandler converts a pre-existing APIGHandler to an APIGMiddleware type. +// The before boolean determines whether the converted handler runs before and +// causes a return on error, or after and does not return if the first handler fails. +func ConvertHandler(handler APIGHandler, before bool) APIGMiddleware { + if before { + return func(next APIGHandler) APIGHandler { + return func(ctx *APIGContext) { + handler(ctx) + if ctx.Err != nil { + return + } + next(ctx) + } + } + } + return func(first APIGHandler) APIGHandler { + return func(ctx *APIGContext) { + first(ctx) + handler(ctx) + } + } +} + // NOTE: Begin helper functions. type routeFunctions struct { handler APIGHandler diff --git a/router_test.go b/router_test.go index e75227c..f0658ba 100644 --- a/router_test.go +++ b/router_test.go @@ -15,16 +15,13 @@ func TestRouterSpec(t *testing.T) { Convey("Given an instantiated router with an error reporting middleware, headers, and prefix", t, func() { request := events.APIGatewayProxyRequest{} - errorReporter := func(handler APIGHandler) APIGHandler { - return func(ctx *APIGContext) { - handler(ctx) - if ctx.Err != nil { - ctx.Body, _ = json.Marshal(map[string]string{"error": ctx.Err.Error()}) - if strings.Contains(ctx.Err.Error(), "record not found") { - ctx.Status = 204 - } else if ctx.Status != 204 && ctx.Status < 400 { - ctx.Status = 400 - } + errorReporter := func(ctx *APIGContext) { + if ctx.Err != nil { + ctx.Body, _ = json.Marshal(map[string]string{"error": ctx.Err.Error()}) + if strings.Contains(ctx.Err.Error(), "record not found") { + ctx.Status = 204 + } else if ctx.Status != 204 && ctx.Status < 400 { + ctx.Status = 400 } } } @@ -36,7 +33,7 @@ func TestRouterSpec(t *testing.T) { "Access-Control-Allow-Credentials": "true", }, Middleware: []APIGMiddleware{ - errorReporter, + ConvertHandler(errorReporter, false), }, }) @@ -162,20 +159,36 @@ func TestRouterSpec(t *testing.T) { }) Convey("When the handler func does return a status < 400", func() { - middlefunc1 := func(handler APIGHandler) APIGHandler { + middlefunc1 := func(ctx *APIGContext) { + ctx.Status = http.StatusOK + } + middlefunc2 := func(handler APIGHandler) APIGHandler { return func(ctx *APIGContext) { ctx.Status = http.StatusOK ctx.Err = errors.New("bad request") + if ctx.Err != nil { + return + } handler(ctx) + } } + middlefunc3 := func(ctx *APIGContext) { + ctx.Err = errors.New("bad request") + } hdlrfunc := func(ctx *APIGContext) { ctx.Status = http.StatusOK ctx.Body = []byte("hello") } Convey("And a Get handler expecting the pattern /listings/{id}/state/{event} is defined", func() { - rtr.Get("/listings/{id}/state/{event}", hdlrfunc, middlefunc1) + rtr.Get( + "/listings/{id}/state/{event}", + hdlrfunc, + ConvertHandler(middlefunc1, true), + ConvertHandler(middlefunc3, true), + middlefunc2, + ) Convey("And the request matches the pattern and the path params are filled", func() { request.HTTPMethod = http.MethodGet