diff --git a/router.go b/router.go index d42054b..d1514d0 100644 --- a/router.go +++ b/router.go @@ -66,35 +66,38 @@ func NewAPIGRouter(r *events.APIGatewayProxyRequest, svcprefix string) *APIGRout } // Get creates a new get endpoint. -func (r *APIGRouter) Get(route string, handler APIGHandler) { - r.addEndpoint(get, route, handler) +func (r *APIGRouter) Get(route string, handlers ...APIGHandler) { + r.addEndpoint(get, route, handlers) } // Post creates a new post endpoint. -func (r *APIGRouter) Post(route string, handler APIGHandler) { - r.addEndpoint(post, route, handler) +func (r *APIGRouter) Post(route string, handlers ...APIGHandler) { + r.addEndpoint(post, route, handlers) } // Put creates a new put endpoint. -func (r *APIGRouter) Put(route string, handler APIGHandler) { - r.addEndpoint(put, route, handler) +func (r *APIGRouter) Put(route string, handlers ...APIGHandler) { + r.addEndpoint(put, route, handlers) } // Patch creates a new patch endpoint -func (r *APIGRouter) Patch(route string, handler APIGHandler) { - r.addEndpoint(patch, route, handler) +func (r *APIGRouter) Patch(route string, handlers ...APIGHandler) { + r.addEndpoint(patch, route, handlers) } // Delete creates a new delete endpoint. -func (r *APIGRouter) Delete(route string, handler APIGHandler) { - r.addEndpoint(delete, route, handler) +func (r *APIGRouter) Delete(route string, handlers ...APIGHandler) { + r.addEndpoint(delete, route, handlers) } // Respond returns an APIGatewayProxyResponse to respond to the lambda request. func (r *APIGRouter) Respond() events.APIGatewayProxyResponse { var ( - handlerInterface interface{} - ok bool + handlersInterface interface{} + ok bool + status int + respbody []byte + err error endpointTree = r.endpoints[r.request.HTTPMethod] path = strings.TrimPrefix(r.request.Path, "/"+r.svcprefix) @@ -109,7 +112,7 @@ func (r *APIGRouter) Respond() events.APIGatewayProxyResponse { } } - if handlerInterface, ok = endpointTree.Get(path); !ok { + if handlersInterface, ok = endpointTree.Get(path); !ok { respbody, _ := json.Marshal(map[string]string{"error": "no route matching path found"}) response.StatusCode = http.StatusNotFound @@ -117,32 +120,34 @@ func (r *APIGRouter) Respond() events.APIGatewayProxyResponse { return response } - handler := handlerInterface.(APIGHandler) + handlers := handlersInterface.([]APIGHandler) - req := &APIGRequest{ - Path: r.request.PathParameters, - QryStr: r.request.QueryStringParameters, - Request: r.request, - } - if r.request.RequestContext.Authorizer["claims"] != nil { - req.Claims = r.request.RequestContext.Authorizer["claims"].(map[string]interface{}) - } - res := &APIGResponse{} - - handler(req, res) - status, respbody, err := res.deconstruct() - - if err != nil { - respbody, _ := json.Marshal(map[string]string{"error": err.Error()}) - if strings.Contains(err.Error(), "record not found") { - status = 204 - } else if status < 400 { - status = 400 + for _, handler := range handlers { + req := &APIGRequest{ + Path: r.request.PathParameters, + QryStr: r.request.QueryStringParameters, + Request: r.request, } + if r.request.RequestContext.Authorizer["claims"] != nil { + req.Claims = r.request.RequestContext.Authorizer["claims"].(map[string]interface{}) + } + res := &APIGResponse{} - response.StatusCode = status - response.Body = string(respbody) - return response + handler(req, res) + status, respbody, err = res.deconstruct() + + if err != nil { + respbody, _ := json.Marshal(map[string]string{"error": err.Error()}) + if strings.Contains(err.Error(), "record not found") { + status = 204 + } else if status < 400 { + status = 400 + } + + response.StatusCode = status + response.Body = string(respbody) + return response + } } response.StatusCode = status @@ -161,8 +166,8 @@ func (res *APIGResponse) deconstruct() (int, []byte, error) { return res.Status, res.Body, res.Err } -func (r *APIGRouter) addEndpoint(method string, route string, handler APIGHandler) { - if _, overwrite := r.endpoints[method].Insert(route, handler); overwrite { +func (r *APIGRouter) addEndpoint(method string, route string, handlers []APIGHandler) { + if _, overwrite := r.endpoints[method].Insert(route, handlers); overwrite { panic("endpoint already existent") } diff --git a/router_test.go b/router_test.go index b681803..723a7c5 100644 --- a/router_test.go +++ b/router_test.go @@ -20,7 +20,6 @@ func TestRouterSpec(t *testing.T) { res.Status = http.StatusOK res.Body = []byte("hello") res.Err = nil - } Convey("And a Get handler expecting the pattern /orders/filter/by_user/{id} is defined", func() { @@ -128,15 +127,24 @@ func TestRouterSpec(t *testing.T) { }) Convey("When the handler func does return a status < 400", func() { - hdlrfunc := func(req *APIGRequest, res *APIGResponse) { + middlefunc1 := func(req *APIGRequest, res *APIGResponse) { + res.Status = http.StatusOK + res.Body = []byte("hello") + res.Err = nil + } + middlefunc2 := func(req *APIGRequest, res *APIGResponse) { res.Status = http.StatusOK res.Body = []byte("hello") res.Err = errors.New("bad request") - + } + hdlrfunc := func(req *APIGRequest, res *APIGResponse) { + res.Status = http.StatusOK + res.Body = []byte("hello") + res.Err = nil } Convey("And a Get handler expecting the pattern /orders/filter/by_user/{id} is defined", func() { - rtr.Get("/orders/filter/by_user/{id}", hdlrfunc) + rtr.Get("/orders/filter/by_user/{id}", middlefunc1, middlefunc2, hdlrfunc) Convey("And the request matches the pattern and the path params are filled", func() { request.HTTPMethod = http.MethodGet